/*
 * Copyright (c) 2017-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
Michalis Spyrouf4643372019-11-29 16:17:13 +000025#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
26#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010027
28#include "arm_compute/core/AccessWindowStatic.h"
29#include "arm_compute/core/NEON/NEFixedPoint.h"
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +000030#include "arm_compute/core/NEON/wrapper/wrapper.h"
31#include "arm_compute/core/utils/misc/Requires.h"
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010032
33#include <arm_neon.h>
34
35namespace arm_compute
36{
37namespace detail
38{
/** Broadcasts the three weights of one row of a 3x3 float matrix.
 *
 * Each returned vector holds a single weight replicated in all four lanes:
 * val[0] is the first element of the row, val[1] the second and val[2] the third.
 *
 * @param[in] ptr            Pointer to the first of three consecutive float weights.
 * @param[in] weights_offset (Optional) Weights quantization offset. Ignored by the float overload.
 *
 * @return The three broadcast weights.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);

    float32x4x3_t row;
    row.val[0] = vld1q_dup_f32(ptr);
    row.val[1] = vld1q_dup_f32(ptr + 1);
    row.val[2] = vld1q_dup_f32(ptr + 2);
    return row;
}
59
/** Broadcasts the three weights of one row of a 3x3 8-bit matrix, widened to 32 bit.
 *
 * Each returned vector holds one weight, with @p weights_offset added, replicated in all four lanes:
 * val[0] is built from the first element of the row, val[1] from the second and val[2] from the third.
 *
 * @param[in] ptr            Pointer to the first of three consecutive uint8_t/int8_t weights.
 * @param[in] weights_offset (Optional) Weights quantization offset added to every weight.
 *
 * @return The three broadcast, offset-corrected weights.
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
{
    const int32x4_t offset = vdupq_n_s32(weights_offset);

    int32x4x3_t row;
    for(int col = 0; col < 3; ++col)
    {
        // The 8-bit weight is promoted to 32 bit before being broadcast
        row.val[col] = vaddq_s32(offset, vdupq_n_s32(ptr[col]));
    }
    return row;
}
84
/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * For every output lane i the three taps of a row are read at offsets i, i + dilation_x and
 * i + 2 * dilation_x from the corresponding row pointer.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset. Unused for float32.
 *
 * @return The four convolution results.
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };

    // Accumulate top, middle and bottom rows in this fixed order
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out             = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out             = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}
141
/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * Two groups of four stride-1 results are computed, starting at element 0 and element 4 of the rows.
 * For stridex == 2, results 0, 2, 4 and 6 are then gathered into val[0].
 * For stridex == 3, result 3 is copied into lane 1 of val[0].
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolution results, arranged according to @p stridex as described above.
 */
inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t result;
    result.val[0] = single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    result.val[1] = single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset);

    switch(stridex)
    {
        case 2:
            // Keep every second result, packed into the first vector
            result.val[0] = vsetq_lane_f32(vgetq_lane_f32(result.val[0], 2), result.val[0], 1);
            result.val[0] = vsetq_lane_f32(vgetq_lane_f32(result.val[1], 0), result.val[0], 2);
            result.val[0] = vsetq_lane_f32(vgetq_lane_f32(result.val[1], 2), result.val[0], 3);
            break;
        case 3:
            // Move the fourth result next to the first one
            result.val[0] = vsetq_lane_f32(vgetq_lane_f32(result.val[0], 3), result.val[0], 1);
            break;
        default:
            break;
    }

    return result;
}
181
/** Perform a convolve3x3 on float32.
 *
 * For stridex == 2 only val[0] is computed (four results, one for every second input column) and
 * val[1] is left at zero. Otherwise eight stride-1 results are computed in val[0] and val[1]; for
 * stridex == 3 the result in lane 3 of val[0] is additionally copied into lane 1 of val[0].
 * Twelve elements are read from each row pointer.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset. Unused for float32.
 *
 * @return The convolution results, arranged according to @p stridex as described above.
 */
inline float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low,
                                  const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                  unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t out;
    out.val[0] = vdupq_n_f32(0.f);
    out.val[1] = vdupq_n_f32(0.f);

    if(stridex == 2)
    {
        // De-interleaving loads: val[0] holds elements 0, 2, 4, 6 and val[1] elements 1, 3, 5, 7
        const float32x4x2_t top = vld2q_f32(in_top);
        const float32x4x2_t mid = vld2q_f32(in_mid);
        const float32x4x2_t low = vld2q_f32(in_low);
        // Element 8 is needed to build the third tap (elements 2, 4, 6, 8)
        const float32x4_t top_tail = vld1q_f32(in_top + 8);
        const float32x4_t mid_tail = vld1q_f32(in_mid + 8);
        const float32x4_t low_tail = vld1q_f32(in_low + 8);

        float32x4_t acc = vmulq_f32(top.val[0], m0.val[0]);
        acc             = vmlaq_f32(acc, top.val[1], m0.val[1]);
        acc             = vmlaq_f32(acc, vextq_f32(top.val[0], top_tail, 1), m0.val[2]);

        acc = vmlaq_f32(acc, mid.val[0], m1.val[0]);
        acc = vmlaq_f32(acc, mid.val[1], m1.val[1]);
        acc = vmlaq_f32(acc, vextq_f32(mid.val[0], mid_tail, 1), m1.val[2]);

        acc = vmlaq_f32(acc, low.val[0], m2.val[0]);
        acc = vmlaq_f32(acc, low.val[1], m2.val[1]);
        acc = vmlaq_f32(acc, vextq_f32(low.val[0], low_tail, 1), m2.val[2]);

        out.val[0] = acc;
        return out;
    }

    // Three consecutive quads per row: elements 0-3, 4-7 and 8-11
    const float32x4_t top_a = vld1q_f32(in_top);
    const float32x4_t top_b = vld1q_f32(in_top + 4);
    const float32x4_t top_c = vld1q_f32(in_top + 8);
    const float32x4_t mid_a = vld1q_f32(in_mid);
    const float32x4_t mid_b = vld1q_f32(in_mid + 4);
    const float32x4_t mid_c = vld1q_f32(in_mid + 8);
    const float32x4_t low_a = vld1q_f32(in_low);
    const float32x4_t low_b = vld1q_f32(in_low + 4);
    const float32x4_t low_c = vld1q_f32(in_low + 8);

    // Results for output columns 0-3
    float32x4_t acc0 = vmulq_f32(top_a, m0.val[0]);
    acc0             = vmlaq_f32(acc0, vextq_f32(top_a, top_b, 1), m0.val[1]);
    acc0             = vmlaq_f32(acc0, vextq_f32(top_a, top_b, 2), m0.val[2]);

    acc0 = vmlaq_f32(acc0, mid_a, m1.val[0]);
    acc0 = vmlaq_f32(acc0, vextq_f32(mid_a, mid_b, 1), m1.val[1]);
    acc0 = vmlaq_f32(acc0, vextq_f32(mid_a, mid_b, 2), m1.val[2]);

    acc0 = vmlaq_f32(acc0, low_a, m2.val[0]);
    acc0 = vmlaq_f32(acc0, vextq_f32(low_a, low_b, 1), m2.val[1]);
    acc0 = vmlaq_f32(acc0, vextq_f32(low_a, low_b, 2), m2.val[2]);

    // Results for output columns 4-7
    float32x4_t acc1 = vmulq_f32(top_b, m0.val[0]);
    acc1             = vmlaq_f32(acc1, vextq_f32(top_b, top_c, 1), m0.val[1]);
    acc1             = vmlaq_f32(acc1, vextq_f32(top_b, top_c, 2), m0.val[2]);

    acc1 = vmlaq_f32(acc1, mid_b, m1.val[0]);
    acc1 = vmlaq_f32(acc1, vextq_f32(mid_b, mid_c, 1), m1.val[1]);
    acc1 = vmlaq_f32(acc1, vextq_f32(mid_b, mid_c, 2), m1.val[2]);

    acc1 = vmlaq_f32(acc1, low_b, m2.val[0]);
    acc1 = vmlaq_f32(acc1, vextq_f32(low_b, low_c, 1), m2.val[1]);
    acc1 = vmlaq_f32(acc1, vextq_f32(low_b, low_c, 2), m2.val[2]);

    if(stridex == 3)
    {
        // Move the fourth result next to the first one
        acc0 = vsetq_lane_f32(vgetq_lane_f32(acc0, 3), acc0, 1);
    }

    out.val[0] = acc0;
    out.val[1] = acc1;
    return out;
}
293
/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * Every input vector is widened and @p input_offset is added to it before it is multiplied with
 * the filter. The taps of a row are read at offsets 0, dilation_x and 2 * dilation_x from the
 * corresponding row pointer.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The four 32-bit convolution results.
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int32_t input_offset)
{
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    // Loads one 8-bit vector, keeps its low half widened to 32 bit and adds the input offset
    const auto load_widened = [&v_input_offset](const T *src) -> int32x4_t
    {
        const auto raw = wrapper::vload(src);
        return wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(raw))));
    };

    const int32x4_t top0 = load_widened(in_top);
    const int32x4_t top1 = load_widened(in_top + dilation_x);
    const int32x4_t top2 = load_widened(in_top + 2 * dilation_x);

    const int32x4_t mid0 = load_widened(in_mid);
    const int32x4_t mid1 = load_widened(in_mid + dilation_x);
    const int32x4_t mid2 = load_widened(in_mid + 2 * dilation_x);

    const int32x4_t low0 = load_widened(in_low);
    const int32x4_t low1 = load_widened(in_low + dilation_x);
    const int32x4_t low2 = load_widened(in_low + 2 * dilation_x);

    int32x4_t acc = wrapper::vmul(top0, m0.val[0]);
    acc           = wrapper::vmla(acc, top1, m0.val[1]);
    acc           = wrapper::vmla(acc, top2, m0.val[2]);

    acc = wrapper::vmla(acc, mid0, m1.val[0]);
    acc = wrapper::vmla(acc, mid1, m1.val[1]);
    acc = wrapper::vmla(acc, mid2, m1.val[2]);

    acc = wrapper::vmla(acc, low0, m2.val[0]);
    acc = wrapper::vmla(acc, low1, m2.val[1]);
    acc = wrapper::vmla(acc, low2, m2.val[2]);

    return acc;
}
380
/** Perform a 3x3 convolution for 8 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * Two groups of four stride-1 results are computed, starting at element 0 and element 4 of the rows.
 * For stridex == 2, results 0, 2, 4 and 6 are then gathered into val[0].
 * For stridex == 3, result 3 is copied into lane 1 of val[0].
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The 32-bit convolution results, arranged according to @p stridex as described above.
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                         const size_t dilation_x, unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    int32x4x2_t result;
    result.val[0] = single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    result.val[1] = single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset);

    switch(stridex)
    {
        case 2:
            // Keep every second result, packed into the first vector
            result.val[0] = wrapper::vsetlane(wrapper::vgetlane(result.val[0], 2), result.val[0], 1);
            result.val[0] = wrapper::vsetlane(wrapper::vgetlane(result.val[1], 0), result.val[0], 2);
            result.val[0] = wrapper::vsetlane(wrapper::vgetlane(result.val[1], 2), result.val[0], 3);
            break;
        case 3:
            // Move the fourth result next to the first one
            result.val[0] = wrapper::vsetlane(wrapper::vgetlane(result.val[0], 3), result.val[0], 1);
            break;
        default:
            break;
    }
    return result;
}
419
/** Perform a convolve3x3 on 8-bit elements.
 *
 * Sixteen elements are loaded from each row pointer; the first twelve are widened to 32 bit and
 * have @p input_offset added before being multiplied with the filter. Eight stride-1 results are
 * computed in val[0] and val[1].
 * For stridex == 2, results 0, 2, 4 and 6 are then gathered into val[0].
 * For stridex == 3, result 3 is copied into lane 1 of val[0].
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The 32-bit convolution results, arranged according to @p stridex as described above.
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
int32x4x2_t convolve_3x3(const T *in_top, const T *in_mid, const T *in_low,
                         const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                         unsigned int stridex, int32_t input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    using VectorType    = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    // Loads 16 elements of a row and returns elements 0-3, 4-7 and 8-11 widened to 32 bit with the offset added
    const auto widen_row = [&v_input_offset](const T *row) -> int32x4x3_t
    {
        const VectorType raw =
        {
            {
                wrapper::vload(row),
                wrapper::vload(row + 8)
            }
        };
        const auto first_wide  = wrapper::vmovl(raw.val[0]);
        const auto second_wide = wrapper::vmovl(raw.val[1]);

        const int32x4x3_t widened =
        {
            {
                wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(first_wide))),
                wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(first_wide))),
                wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(second_wide)))
            }
        };
        return widened;
    };

    // Adds the three taps of one filter row, applied to the window starting in cur, onto acc
    const auto accumulate_row = [](int32x4_t acc, const int32x4_t &cur, const int32x4_t &next, const int32x4x3_t &weights) -> int32x4_t
    {
        acc = wrapper::vmla(acc, cur, weights.val[0]);
        acc = wrapper::vmla(acc, wrapper::vext_1(cur, next), weights.val[1]);
        return wrapper::vmla(acc, wrapper::vext_2(cur, next), weights.val[2]);
    };

    const int32x4x3_t top = widen_row(in_top);
    const int32x4x3_t mid = widen_row(in_mid);
    const int32x4x3_t low = widen_row(in_low);

    // Results for output columns 0-3
    int32x4_t acc0 = wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{});
    acc0           = accumulate_row(acc0, top.val[0], top.val[1], m0);
    acc0           = accumulate_row(acc0, mid.val[0], mid.val[1], m1);
    acc0           = accumulate_row(acc0, low.val[0], low.val[1], m2);

    // Results for output columns 4-7
    int32x4_t acc1 = wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{});
    acc1           = accumulate_row(acc1, top.val[1], top.val[2], m0);
    acc1           = accumulate_row(acc1, mid.val[1], mid.val[2], m1);
    acc1           = accumulate_row(acc1, low.val[1], low.val[2], m2);

    int32x4x2_t out;
    out.val[0] = acc0;
    out.val[1] = acc1;

    switch(stridex)
    {
        case 2:
            // Keep every second result, packed into the first vector
            out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
            out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
            out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
            break;
        case 3:
            // Move the fourth result next to the first one
            out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
            break;
        default:
            break;
    }
    return out;
}
536
/** Stores the results held in a float32x4x2_t into a memory location.
 *
 * The number of floats written depends on the stride: 8 for stridex == 1 (both vectors),
 * 4 for stridex == 2 (val[0] only) and 2 for stridex == 3 (low half of val[0] only).
 *
 * @param[out] buffer Pointer to the memory location where the values will be stored.
 * @param[in]  values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    for(int quad = 0; quad < 2; ++quad)
    {
        vst1q_f32(buffer + 4 * quad, values.val[quad]);
    }
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t packed = values.val[0];
    vst1q_f32(buffer, packed);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    const float32x2_t first_two = vget_low_f32(values.val[0]);
    vst1_f32(buffer, first_two);
}
564
/** Stores the results held in an int32x4x2_t into a memory location.
 *
 * The number of int32_t values written depends on the stride: 8 for stridex == 1 (both vectors),
 * 4 for stridex == 2 (val[0] only) and 2 for stridex == 3 (low half of val[0] only).
 *
 * @param[out] buffer Pointer to the memory location where the values will be stored.
 * @param[in]  values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    for(int quad = 0; quad < 2; ++quad)
    {
        vst1q_s32(buffer + 4 * quad, values.val[quad]);
    }
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    const int32x4_t packed = values.val[0];
    vst1q_s32(buffer, packed);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    const int32x2_t first_two = vget_low_s32(values.val[0]);
    vst1_s32(buffer, first_two);
}
592
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000593#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Broadcasts the three weights of one row of a 3x3 float16_t matrix.
 *
 * Each returned vector holds a single weight replicated in all eight lanes:
 * val[0] is the first element of the row, val[1] the second and val[2] the third.
 *
 * @param[in] ptr            Pointer to the first of three consecutive float16_t weights.
 * @param[in] weights_offset (Optional) Weights quantization offset. Ignored by the float16_t overload.
 *
 * @return The three broadcast weights.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);

    float16x8x3_t row;
    row.val[0] = vld1q_dup_f16(ptr);
    row.val[1] = vld1q_dup_f16(ptr + 1);
    row.val[2] = vld1q_dup_f16(ptr + 2);
    return row;
}
615
/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * For every output lane i the three taps of a row are read at offsets i, i + dilation_x and
 * i + 2 * dilation_x from the corresponding row pointer. Each product is computed with a separate
 * multiply and then added to the running sum.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset. Unused for float16.
 *
 * @return The eight convolution results.
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float16x8_t top0 = vld1q_f16(in_top);
    const float16x8_t top1 = vld1q_f16(in_top + dilation_x);
    const float16x8_t top2 = vld1q_f16(in_top + 2 * dilation_x);

    const float16x8_t mid0 = vld1q_f16(in_mid);
    const float16x8_t mid1 = vld1q_f16(in_mid + dilation_x);
    const float16x8_t mid2 = vld1q_f16(in_mid + 2 * dilation_x);

    const float16x8_t low0 = vld1q_f16(in_low);
    const float16x8_t low1 = vld1q_f16(in_low + dilation_x);
    const float16x8_t low2 = vld1q_f16(in_low + 2 * dilation_x);

    // Accumulate top, middle and bottom rows in this fixed order
    float16x8_t acc = vmulq_f16(top0, m0.val[0]);
    acc             = vaddq_f16(acc, vmulq_f16(top1, m0.val[1]));
    acc             = vaddq_f16(acc, vmulq_f16(top2, m0.val[2]));

    acc = vaddq_f16(acc, vmulq_f16(mid0, m1.val[0]));
    acc = vaddq_f16(acc, vmulq_f16(mid1, m1.val[1]));
    acc = vaddq_f16(acc, vmulq_f16(mid2, m1.val[2]));

    acc = vaddq_f16(acc, vmulq_f16(low0, m2.val[0]));
    acc = vaddq_f16(acc, vmulq_f16(low1, m2.val[1]));
    acc = vaddq_f16(acc, vmulq_f16(low2, m2.val[2]));

    return acc;
}
671
672/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
673 *
674 * @param[in] in_top Pointer to the first row of the input.
675 * @param[in] in_mid Pointer to the second row of the input.
676 * @param[in] in_low Pointer to the third row of the input.
677 * @param[in] m0 First row of the filter.
678 * @param[in] m1 Second row of the filter.
679 * @param[in] m2 Third row of the filter.
680 * @param[in] dilation_x Dilation, in elements across x.
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000681 * @param[in] stridex Stride value in elements across x.
682 * @param[in] input_offset (Optional) Input quantization offset.
Usama Arif881f2de2019-04-12 10:29:17 +0100683 *
684 */
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000685inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
686 const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
687 const size_t dilation_x, unsigned int stridex, int input_offset = 0)
Usama Arif881f2de2019-04-12 10:29:17 +0100688{
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000689 float16x8x2_t out =
Usama Arif881f2de2019-04-12 10:29:17 +0100690 {
691 {
692 single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
693 single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
694 }
695 };
Usama Arif881f2de2019-04-12 10:29:17 +0100696
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000697 if(stridex == 2)
698 {
699 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
700 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
701 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
702 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
703 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
704 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
705 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
706 }
707 else if(stridex == 3)
708 {
709 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
710 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
711 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
712 }
Usama Arif881f2de2019-04-12 10:29:17 +0100713
Usama Arif881f2de2019-04-12 10:29:17 +0100714 return out;
715}
716
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100717/** Perform a convolve3x3 on float16.
718 *
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000719 * @param[in] in_top Pointer to the first row of the input.
720 * @param[in] in_mid Pointer to the second row of the input.
721 * @param[in] in_low Pointer to the third row of the input.
722 * @param[in] m0 First row of the filter.
723 * @param[in] m1 Second row of the filter.
724 * @param[in] m2 Third row of the filter.
725 * @param[in] stridex Stride value in elements across x.
726 * @param[in] input_offset (Optional) Input quantization offset.
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100727 *
728 */
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000729inline float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
730 const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
731 unsigned int stridex, int input_offset = 0)
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100732{
Georgios Pinitas20c246a2018-09-12 16:45:53 +0100733 ARM_COMPUTE_UNUSED(input_offset);
alankelly1f103d32019-05-15 23:05:31 +0200734
735 float16x8x2_t out =
Michalis Spyrouf4643372019-11-29 16:17:13 +0000736 {
alankelly1f103d32019-05-15 23:05:31 +0200737 {
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000738 vdupq_n_f16(0),
Michalis Spyrouf4643372019-11-29 16:17:13 +0000739 vdupq_n_f16(0)
740 }
741 };
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000742 if(stridex == 2)
743 {
744 const float16x8x2_t vtop = vld2q_f16(in_top);
745 const float16x8x2_t vmid = vld2q_f16(in_mid);
746 const float16x8x2_t vlow = vld2q_f16(in_low);
747 const float16x8_t vtop_end = vld1q_f16(in_top + 16);
748 const float16x8_t vmid_end = vld1q_f16(in_mid + 16);
749 const float16x8_t vlow_end = vld1q_f16(in_low + 16);
alankelly1f103d32019-05-15 23:05:31 +0200750
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000751 out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
alankelly1f103d32019-05-15 23:05:31 +0200752
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000753 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
754 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));
alankelly1f103d32019-05-15 23:05:31 +0200755
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000756 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
757 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
758 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100759
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000760 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
761 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
762 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));
763 }
764 else
765 {
766 const float16x8x3_t vtop =
767 {
768 {
769 vld1q_f16(in_top),
770 vld1q_f16(in_top + 8),
771 vld1q_f16(in_top + 16)
772 }
773 };
774 const float16x8x3_t vmid =
775 {
776 {
777 vld1q_f16(in_mid),
778 vld1q_f16(in_mid + 8),
779 vld1q_f16(in_mid + 16)
780 }
781 };
782 const float16x8x3_t vlow =
783 {
784 {
785 vld1q_f16(in_low),
786 vld1q_f16(in_low + 8),
787 vld1q_f16(in_low + 16)
788 }
789 };
790 out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
791 out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);
792
793 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
794 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
795 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
796 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
797 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
798 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
799 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
800 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
801 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
802 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
803 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
804 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
805 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
806 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
807 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
808 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
809
810 if(stridex == 3)
811 {
812 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
813 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
814 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
815 }
816 }
817
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100818 return out;
819}
820
821/** Stores a float16x8x2_t array into a memory location.
822 *
823 * @param[in] buffer Pointer to the memory location where the values will be stored.
824 * @param[in] values Values that will be stored.
825 *
826 */
827template <unsigned int stridex>
828void store_results(float16_t *buffer, const float16x8x2_t &values);
829
830template <>
831inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
832{
833 vst1q_f16(buffer, values.val[0]);
834 vst1q_f16(buffer + 8, values.val[1]);
835}
836
837template <>
838inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
839{
840 vst1q_f16(buffer, values.val[0]);
841}
842
843template <>
844inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
845{
846 vst1_f16(buffer, vget_low_f16(values.val[0]));
847}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000848#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100849
850/** Get the number of elements processed on 3x3 convolution.
851 *
852 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000853 * @param[in] stridex Stride value in elements across x.
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100854 *
855 * @return The number of elements processed.
856 */
Anthony Barbier15686212017-12-12 17:17:50 +0000857inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
858{
859 switch(stridex)
860 {
861 case 1:
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000862 return num_elems_written_per_iteration;
Anthony Barbier15686212017-12-12 17:17:50 +0000863 case 2:
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000864 return num_elems_written_per_iteration << 1;
Anthony Barbier15686212017-12-12 17:17:50 +0000865 case 3:
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000866 return num_elems_written_per_iteration * 3;
Anthony Barbier15686212017-12-12 17:17:50 +0000867 default:
868 ARM_COMPUTE_ERROR("stridex not supported");
869 return 0;
870 }
871}
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100872}
873} // namespace arm_compute
Michalis Spyrouf4643372019-11-29 16:17:13 +0000874#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */