/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/utils/misc/Requires.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/wrapper/wrapper.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
/** Loads a 3x3 matrix as a row (float).
 *
 * @param[in] ptr            Pointer to a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

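// Illustrative usage sketch (not part of the library documentation; "weights" here is a
// hypothetical pointer to the three row-major coefficients of one filter row):
//
//   const float         weights[3] = { 0.25f, 0.5f, 0.25f };
//   const float32x4x3_t m0         = load_matrix_row(weights);
//   // m0.val[0] holds 0.25f in all four lanes, m0.val[1] holds 0.5f, m0.val[2] holds 0.25f.
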
/** Loads a 3x3 matrix as a row (uint8_t/int8_t).
 *
 * @param[in] ptr            Pointer to a uint8_t/int8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
{
    const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);

    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    int32x4x3_t r =
    {
        {
            vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
        }
    };
    return r;
}

/** Stores a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

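// Illustrative note (behaviour read off the specializations above): the template parameter
// is the convolution stride across x, and it selects how many of the eight computed lanes
// are written out: all 8 for stride 1, the first 4 for stride 2 and the first 2 for stride 3.
// For example:
//
//   float               buf[8] = {};
//   const float32x4x2_t vals   = { { vdupq_n_f32(1.f), vdupq_n_f32(2.f) } };
//   store_results<1>(buf, vals); // writes buf[0..7]
//   store_results<3>(buf, vals); // writes only buf[0] and buf[1]
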
/** Stores an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vget_low_s32(values.val[0]));
}

template <unsigned int stridex>
inline void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
inline void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void accumulate_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
    vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
}

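// Illustrative note: accumulate_results<stridex> mirrors store_results<stridex>, except that
// the computed lanes are added to the values already present in the destination buffer
// instead of overwriting them. For example ("partial_sums" is a hypothetical, previously
// computed int32x4x2_t):
//
//   int32_t acc[8] = { 0 };                   // hypothetical accumulator buffer
//   accumulate_results<1>(acc, partial_sums); // acc[i] += lane i of partial_sums
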
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Stores a float16x8x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
inline void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
inline void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    float32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    }

    return out;
}

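// Illustrative usage sketch (hypothetical names; the row pointers are assumed to point into a
// suitably padded input so that every vector load stays in bounds):
//
//   const float32x4x3_t m0  = load_matrix_row(weights);     // first filter row
//   const float32x4x3_t m1  = load_matrix_row(weights + 3); // second filter row
//   const float32x4x3_t m2  = load_matrix_row(weights + 6); // third filter row
//   const float32x4x2_t out = convolve_3x3_dilation(row0, row1, row2, m0, m1, m2, 2, 1);
//   store_results<1>(out_ptr, out); // stride 1: all 8 lanes are valid outputs
//
// Here 2 is dilation_x and 1 is stridex.
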
/** Perform a convolve3x3 on float32.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                  const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                  unsigned int stridex, int input_offset = 0);

template <bool accumulate>
inline void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                         const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                         unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t out =
    {
        {
            vdupq_n_f32(0.f),
            vdupq_n_f32(0.f)
        }
    };
    if(stridex == 2)
    {
        const float32x4x2_t vtop     = vld2q_f32(in_top);
        const float32x4x2_t vmid     = vld2q_f32(in_mid);
        const float32x4x2_t vlow     = vld2q_f32(in_low);
        const float32x4_t   vtop_end = vld1q_f32(in_top + 8);
        const float32x4_t   vmid_end = vld1q_f32(in_mid + 8);
        const float32x4_t   vlow_end = vld1q_f32(in_low + 8);

        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float32x4x3_t vtop =
        {
            {
                vld1q_f32(in_top),
                vld1q_f32(in_top + 4),
                vld1q_f32(in_top + 8)
            }
        };
        const float32x4x3_t vmid =
        {
            {
                vld1q_f32(in_mid),
                vld1q_f32(in_mid + 4),
                vld1q_f32(in_mid + 8)
            }
        };
        const float32x4x3_t vlow =
        {
            {
                vld1q_f32(in_low),
                vld1q_f32(in_low + 4),
                vld1q_f32(in_low + 8)
            }
        };
        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}

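// Illustrative usage sketch (hypothetical names; "input_stride_y" is the element stride
// between consecutive input rows and the pointers are assumed to stay within the padded input):
//
//   const float32x4x3_t m0 = load_matrix_row(weights);
//   const float32x4x3_t m1 = load_matrix_row(weights + 3);
//   const float32x4x3_t m2 = load_matrix_row(weights + 6);
//   // Stride 1: write 8 consecutive outputs starting at out_ptr.
//   convolve_3x3<false>(in, in + input_stride_y, in + 2 * input_stride_y, out_ptr, m0, m1, m2, 1);
//   // Same computation, but accumulating into out_ptr instead of overwriting it.
//   convolve_3x3<true>(in, in + input_stride_y, in + 2 * input_stride_y, out_ptr, m0, m1, m2, 1);
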
/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int32_t input_offset)
{
    using VectorType    = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + dilation_x),
            wrapper::vload(in_top + 2 * dilation_x)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + dilation_x),
            wrapper::vload(in_mid + 2 * dilation_x)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + dilation_x),
            wrapper::vload(in_low + 2 * dilation_x)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))),
        }
    };

    int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]);
    out = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]);
    out = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]);

    out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]);
    out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]);
    out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]);

    out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]);
    out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]);
    out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]);

    return out;
}

/** Perform a 3x3 convolution for 8 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                         const size_t dilation_x, unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    int32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
    }
    return out;
}

/** Perform a convolve3x3 on 8-bit elements.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset Input quantization offset.
 *
 */
template < bool accumulate, typename T1, typename T2, REQUIRES_TA(std::is_same<T1, uint8_t>::value || std::is_same<T1, int8_t>::value) >
void convolve_3x3(const T1 *in_top, const T1 *in_mid, const T1 *in_low, T2 *out_ptr,
                  const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                  unsigned int stridex, int32_t input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    using VectorType    = typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + 8)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + 8)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + 8)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
        }
    };

    int32x4x2_t out
    {
        {
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
        }
    };

    // Compute the first four outputs (out.val[0])
    out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]);

    // Compute the next four outputs (out.val[1])
    out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]);

    if(stridex == 1)
    {
        accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
    }
    else if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
        accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
    }
}

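// Illustrative quantized usage sketch (hypothetical names). The filter rows already carry the
// weights offset through the templated load_matrix_row(ptr, weights_offset), while input_offset
// is added to every widened input value before the multiply-accumulates; in the quantized
// direct convolution kernels these are assumed to be the negated zero points of weights and input.
//
//   const int32x4x3_t m0 = load_matrix_row(weights, weights_offset); // weights is const uint8_t*
//   const int32x4x3_t m1 = load_matrix_row(weights + 3, weights_offset);
//   const int32x4x3_t m2 = load_matrix_row(weights + 6, weights_offset);
//   // Accumulate 8 int32 partial sums at out_ptr (int32_t*) for stride 1:
//   convolve_3x3<true>(in, in + input_stride_y, in + 2 * input_stride_y, out_ptr, m0, m1, m2, 1, input_offset);
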
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads a 3x3 matrix as a row (float16_t).
 *
 * @param[in] ptr            Pointer to a float16_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + dilation_x),
            vld1q_f16(in_top + 2 * dilation_x)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + dilation_x),
            vld1q_f16(in_mid + 2 * dilation_x)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + dilation_x),
            vld1q_f16(in_low + 2 * dilation_x)
        }
    };
    float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
    out = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
    out = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));

    out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));

    out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));

    return out;
}

/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    float16x8x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    }

    return out;
}

/** Perform a convolve3x3 on float16.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
inline void convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, float16_t *out_ptr,
                         const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                         unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float16x8x2_t out =
    {
        {
            vdupq_n_f16(0),
            vdupq_n_f16(0)
        }
    };
    if(stridex == 2)
    {
        const float16x8x2_t vtop     = vld2q_f16(in_top);
        const float16x8x2_t vmid     = vld2q_f16(in_mid);
        const float16x8x2_t vlow     = vld2q_f16(in_low);
        const float16x8_t   vtop_end = vld1q_f16(in_top + 16);
        const float16x8_t   vmid_end = vld1q_f16(in_mid + 16);
        const float16x8_t   vlow_end = vld1q_f16(in_low + 16);

        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float16x8x3_t vtop =
        {
            {
                vld1q_f16(in_top),
                vld1q_f16(in_top + 8),
                vld1q_f16(in_top + 16)
            }
        };
        const float16x8x3_t vmid =
        {
            {
                vld1q_f16(in_mid),
                vld1q_f16(in_mid + 8),
                vld1q_f16(in_mid + 16)
            }
        };
        const float16x8x3_t vlow =
        {
            {
                vld1q_f16(in_low),
                vld1q_f16(in_low + 8),
                vld1q_f16(in_low + 16)
            }
        };
        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);

            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Get the number of elements processed on 3x3 convolution.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
 * @param[in] stridex                         Stride value in elements across x.
 *
 * @return The number of elements processed.
 */
inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
{
    switch(stridex)
    {
        case 1:
            return num_elems_written_per_iteration;
        case 2:
            return num_elems_written_per_iteration << 1;
        case 3:
            return num_elems_written_per_iteration * 3;
        default:
            ARM_COMPUTE_ERROR("stridex not supported");
            return 0;
    }
}
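
// For example, a kernel that writes 4 output elements per iteration consumes 4, 8 or 12
// input elements per iteration for a stride across x of 1, 2 or 3 respectively:
//
//   const int step = get_input_num_elems_processed(4, 2); // step == 8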
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */