blob: 3547d2d110bf8e3002d4e0e21edc31936458e27a [file] [log] [blame]
Michalis Spyrou7362f0d2017-10-18 17:58:22 +01001/*
Usama Arif881f2de2019-04-12 10:29:17 +01002 * Copyright (c) 2017-2019 ARM Limited.
Michalis Spyrou7362f0d2017-10-18 17:58:22 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
26#define __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
27
28#include "arm_compute/core/AccessWindowStatic.h"
29#include "arm_compute/core/NEON/NEFixedPoint.h"
30
31#include <arm_neon.h>
32
33namespace arm_compute
34{
35namespace detail
36{
/** Loads a 3x3 matrix as a row (float).
 *
 * @param[in] ptr            Pointer to a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset. Unused for float data.
 *
 * @return The loaded matrix.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    // Broadcast each of the three row values across all lanes of its own vector:
    // row.val[0] holds ptr[0], row.val[1] holds ptr[1], row.val[2] holds ptr[2].
    float32x4x3_t row{};
    row.val[0] = vld1q_dup_f32(ptr);
    row.val[1] = vld1q_dup_f32(ptr + 1);
    row.val[2] = vld1q_dup_f32(ptr + 2);
    return row;
}
57
/** Loads a 3x3 matrix as a row (uint8_t).
 *
 * @param[in] ptr            Pointer to a uint8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline int32x4x3_t load_matrix_row(const uint8_t *ptr, int weights_offset = 0)
{
    const int32x4_t v_offset = vdupq_n_s32(weights_offset);

    // Each returned vector broadcasts one row element plus the quantization
    // offset across all four lanes: val[0] = ptr[0], val[1] = ptr[1], val[2] = ptr[2].
    int32x4x3_t row{};
    row.val[0] = vaddq_s32(v_offset, vdupq_n_s32(*(ptr + 0)));
    row.val[1] = vaddq_s32(v_offset, vdupq_n_s32(*(ptr + 1)));
    row.val[2] = vaddq_s32(v_offset, vdupq_n_s32(*(ptr + 2)));
    return row;
}
81
/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset. Unused for float data.
 *
 * @return The convolution result for 4 consecutive output elements.
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset)
{
    // Quantization offsets do not apply to float data.
    ARM_COMPUTE_UNUSED(input_offset);

    // For each input row, val[k] holds the 4 values that line up with filter
    // column k; with dilation the columns are dilation_x elements apart.
    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };
    // Multiply-accumulate the nine filter taps; each m*.val[k] broadcasts a
    // single coefficient to all lanes (see load_matrix_row).
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}
138
/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolution result; the number of valid lanes depends on @p stridex.
 */
template <unsigned int stridex>
float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                    const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                    const size_t dilation_x, int input_offset = 0);
155
/* Stride 1: the 8 outputs are two independent batches of 4 consecutive elements. */
template <>
inline float32x4x2_t convolve_3x3_dilation<1>(const float *in_top, const float *in_mid, const float *in_low,
                                              const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                              const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float32x4x2_t out{};
    out.val[0] = single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    out.val[1] = single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset);
    return out;
}
173
/* Stride 2: evaluate at stride 1, then gather every second result into lanes 0-3 of val[0]. */
template <>
inline float32x4x2_t convolve_3x3_dilation<2>(const float *in_top, const float *in_mid, const float *in_low,
                                              const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                              const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float32x4x2_t res = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    const float r1    = vgetq_lane_f32(res.val[0], 2);
    const float r2    = vgetq_lane_f32(res.val[1], 0);
    const float r3    = vgetq_lane_f32(res.val[1], 2);
    res.val[0]        = vsetq_lane_f32(r1, res.val[0], 1);
    res.val[0]        = vsetq_lane_f32(r2, res.val[0], 2);
    res.val[0]        = vsetq_lane_f32(r3, res.val[0], 3);
    return res;
}
187
/* Stride 3: evaluate at stride 1, then compact — only results 0 and 3 are valid. */
template <>
inline float32x4x2_t convolve_3x3_dilation<3>(const float *in_top, const float *in_mid, const float *in_low,
                                              const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                              const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    // NOTE: removed a stray empty statement (lone ';') left after the call below.
    float32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}
200
/** Perform a convolve3x3 on float32.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolution result; the number of valid lanes depends on @p stridex.
 */
template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low,
                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                           int input_offset = 0);
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100216
/* Stride 1 specialization: computes 8 consecutive float32 outputs. */
template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low,
                                     const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                     int input_offset)
{
    // Quantization offsets do not apply to float data.
    ARM_COMPUTE_UNUSED(input_offset);

    // Load 12 consecutive values per row: 8 stride-1 outputs read inputs [0, 10].
    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    // Start both accumulators with filter column 0 times the aligned inputs.
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    // vextq_f32(a, b, n) slides the input window right by n elements, so each
    // filter column multiplies the inputs at the matching x offset.
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

    // Second batch of 4 outputs: same pattern shifted by 4 input elements.
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}
278
/* Stride 2: evaluate at stride 1, then gather every second result into lanes 0-3 of val[0]. */
template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low,
                                     const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                     int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float32x4x2_t res = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
    const float r1    = vgetq_lane_f32(res.val[0], 2);
    const float r2    = vgetq_lane_f32(res.val[1], 0);
    const float r3    = vgetq_lane_f32(res.val[1], 2);
    res.val[0]        = vsetq_lane_f32(r1, res.val[0], 1);
    res.val[0]        = vsetq_lane_f32(r2, res.val[0], 2);
    res.val[0]        = vsetq_lane_f32(r3, res.val[0], 3);
    return res;
}
292
/* Stride 3: evaluate at stride 1, then compact — only results 0 and 3 are valid. */
template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low,
                                     const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                     int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float32x4x2_t res = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
    const float second = vgetq_lane_f32(res.val[0], 3);
    res.val[0]         = vsetq_lane_f32(second, res.val[0], 1);
    return res;
}
304
/** Perform a 3x3 convolution for 4 consecutive elements on uint8_t when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The convolution result for 4 consecutive output elements.
 */
inline int32x4_t single_convolve_3x3_dilation(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int input_offset)
{
    // Broadcast the quantization offset so it can be added lane-wise below.
    const int32x4_t v_input_offset = vdupq_n_s32(input_offset);

    // For each input row, val[k] holds the 8 bytes that line up with filter
    // column k; with dilation the columns are dilation_x elements apart.
    // Only the low 4 lanes are used after widening.
    const uint8x8x3_t vtop =
    {
        {
            vld1_u8(in_top),
            vld1_u8(in_top + dilation_x),
            vld1_u8(in_top + 2 * dilation_x)
        }
    };
    const uint8x8x3_t vmid =
    {
        {
            vld1_u8(in_mid),
            vld1_u8(in_mid + dilation_x),
            vld1_u8(in_mid + 2 * dilation_x)
        }
    };
    const uint8x8x3_t vlow =
    {
        {
            vld1_u8(in_low),
            vld1_u8(in_low + dilation_x),
            vld1_u8(in_low + 2 * dilation_x)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))), //convert from uint8x8 to uint16x8, to uint16x4(lower or bottom half) to int16x4 to int32x4
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[2])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[2])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[2])))),
        }
    };

    // Multiply-accumulate the nine filter taps in int32; each m*.val[k]
    // broadcasts one (offset-adjusted) coefficient to all lanes.
    int32x4_t out = vmulq_s32(vtop_s32.val[0], m0.val[0]);
    out = vmlaq_s32(out, vtop_s32.val[1], m0.val[1]);
    out = vmlaq_s32(out, vtop_s32.val[2], m0.val[2]);

    out = vmlaq_s32(out, vmid_s32.val[0], m1.val[0]);
    out = vmlaq_s32(out, vmid_s32.val[1], m1.val[1]);
    out = vmlaq_s32(out, vmid_s32.val[2], m1.val[2]);

    out = vmlaq_s32(out, vlow_s32.val[0], m2.val[0]);
    out = vmlaq_s32(out, vlow_s32.val[1], m2.val[1]);
    out = vmlaq_s32(out, vlow_s32.val[2], m2.val[2]);

    return out;
}
387
/** Perform a 3x3 convolution for 8 consecutive elements on uint8_t when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The convolution result; the number of valid lanes depends on @p stridex.
 */
template <unsigned int stridex>
int32x4x2_t convolve_3x3_dilation(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                  const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                  const size_t dilation_x, int input_offset);
404
/* Stride 1: the 8 outputs are two independent batches of 4 consecutive elements. */
template <>
inline int32x4x2_t convolve_3x3_dilation<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                            const size_t dilation_x, int input_offset)
{
    int32x4x2_t out{};
    out.val[0] = single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    out.val[1] = single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset);
    return out;
}
418
/* Stride 2: evaluate at stride 1, then gather every second result into lanes 0-3 of val[0]. */
template <>
inline int32x4x2_t convolve_3x3_dilation<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                            const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                            const size_t dilation_x, int input_offset)
{
    int32x4x2_t res = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);

    const int32_t r1 = vgetq_lane_s32(res.val[0], 2);
    const int32_t r2 = vgetq_lane_s32(res.val[1], 0);
    const int32_t r3 = vgetq_lane_s32(res.val[1], 2);
    res.val[0]       = vsetq_lane_s32(r1, res.val[0], 1);
    res.val[0]       = vsetq_lane_s32(r2, res.val[0], 2);
    res.val[0]       = vsetq_lane_s32(r3, res.val[0], 3);
    return res;
}
431
/* Stride 3: evaluate at stride 1, then compact — only results 0 and 3 are valid. */
template <>
inline int32x4x2_t convolve_3x3_dilation<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                            const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                            const size_t dilation_x, int input_offset)
{
    int32x4x2_t res = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    const int32_t second = vgetq_lane_s32(res.val[0], 3);
    res.val[0]           = vsetq_lane_s32(second, res.val[0], 1);
    return res;
}
441
/** Perform a convolve3x3 on uint8_t
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset Input quantization offset.
 *
 * @return The convolution result; the number of valid lanes depends on @p stridex.
 */
template <unsigned int stridex>
int32x4x2_t convolve_3x3(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                         const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                         int input_offset);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000457
/* Stride 1 specialization: computes 8 consecutive int32 outputs from uint8 input. */
template <>
inline int32x4x2_t convolve_3x3<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                   int input_offset)
{
    // Broadcast the quantization offset so it can be added lane-wise below.
    const int32x4_t v_input_offset = vdupq_n_s32(input_offset);

    // Load 16 consecutive bytes per row; only inputs [0, 11] are consumed
    // by the 8 stride-1 outputs after widening.
    const uint8x8x2_t vtop =
    {
        {
            vld1_u8(in_top),
            vld1_u8(in_top + 8)
        }
    };
    const uint8x8x2_t vmid =
    {
        {
            vld1_u8(in_mid),
            vld1_u8(in_mid + 8)
        }
    };
    const uint8x8x2_t vlow =
    {
        {
            vld1_u8(in_low),
            vld1_u8(in_low + 8)
        }
    };

    // Widen uint8 -> int32 and add the input offset: val[0] = inputs 0-3,
    // val[1] = inputs 4-7, val[2] = inputs 8-11.
    const int32x4x3_t vtop_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vtop.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vmid.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vlow.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))),
        }
    };

    int32x4x2_t out
    {
        {
            vdupq_n_s32(0),
            vdupq_n_s32(0),
        }
    };

    // 0
    // First 4 outputs: vextq_s32(a, b, n) slides the widened input window
    // right by n elements to line up with filter column n.
    out.val[0] = vmlaq_s32(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_s32(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_s32(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 2), m2.val[2]);

    // 1
    // Second 4 outputs: same pattern shifted by 4 input elements.
    out.val[1] = vmlaq_s32(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_s32(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_s32(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 2), m2.val[2]);

    return out;
}
547
/* Stride 2: evaluate at stride 1, then gather every second result into lanes 0-3 of val[0]. */
template <>
inline int32x4x2_t convolve_3x3<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                   const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                   int input_offset)
{
    int32x4x2_t res = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
    const int32_t r1 = vgetq_lane_s32(res.val[0], 2);
    const int32_t r2 = vgetq_lane_s32(res.val[1], 0);
    const int32_t r3 = vgetq_lane_s32(res.val[1], 2);
    res.val[0]       = vsetq_lane_s32(r1, res.val[0], 1);
    res.val[0]       = vsetq_lane_s32(r2, res.val[0], 2);
    res.val[0]       = vsetq_lane_s32(r3, res.val[0], 3);
    return res;
}
559
/* Stride 3: evaluate at stride 1, then compact — only results 0 and 3 are valid. */
template <>
inline int32x4x2_t convolve_3x3<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                   const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                   int input_offset)
{
    int32x4x2_t res = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
    const int32_t second = vgetq_lane_s32(res.val[0], 3);
    res.val[0]           = vsetq_lane_s32(second, res.val[0], 1);
    return res;
}
569
/** Stores a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    // Stride 1: all 8 results are valid.
    vst1q_f32(buffer + 0, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    // Stride 2: the 4 valid results live in values.val[0].
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    // Stride 3: only the 2 results in the low half of values.val[0] are valid.
    const float32x2_t lower = vget_low_f32(values.val[0]);
    vst1_f32(buffer, lower);
}
597
/** Stores an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    // Stride 1: all 8 results are valid.
    vst1q_s32(buffer + 0, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    // Stride 2: the 4 valid results live in values.val[0].
    vst1q_s32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    // Stride 3: only the 2 results in the low half of values.val[0] are valid.
    const int32x2_t lower = vget_low_s32(values.val[0]);
    vst1_s32(buffer, lower);
}
625
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000626#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads a 3x3 matrix as a row (float16_t).
 *
 * @param[in] ptr            Pointer to a float16_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset. Unused for float data.
 *
 * @return The loaded matrix.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    // Broadcast each of the three row values across all lanes of its own vector:
    // row.val[0] holds ptr[0], row.val[1] holds ptr[1], row.val[2] holds ptr[2].
    float16x8x3_t row{};
    row.val[0] = vld1q_dup_f16(ptr);
    row.val[1] = vld1q_dup_f16(ptr + 1);
    row.val[2] = vld1q_dup_f16(ptr + 2);
    return row;
}
648
/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset. Unused for float data.
 *
 * @return The convolution result for 8 consecutive output elements.
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    // Quantization offsets do not apply to float data.
    ARM_COMPUTE_UNUSED(input_offset);
    // For each input row, val[k] holds the 8 values that line up with filter
    // column k; with dilation the columns are dilation_x elements apart.
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + dilation_x),
            vld1q_f16(in_top + 2 * dilation_x)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + dilation_x),
            vld1q_f16(in_mid + 2 * dilation_x)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + dilation_x),
            vld1q_f16(in_low + 2 * dilation_x)
        }
    };
    // Multiply-accumulate the nine filter taps; each m*.val[k] broadcasts a
    // single coefficient to all lanes (see load_matrix_row).
    float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
    out = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
    out = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));

    out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));

    out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));

    return out;
}
704
/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return The convolution result; the number of valid lanes depends on @p stridex.
 */
template <unsigned int stridex>
float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                    const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                    const size_t dilation_x, int input_offset = 0);
721
/* Stride 1: the 16 outputs are two independent batches of 8 consecutive elements. */
template <>
inline float16x8x2_t convolve_3x3_dilation<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                              const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                              const size_t dilation_x, int input_offset)
{
    float16x8x2_t out{};
    out.val[0] = single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    out.val[1] = single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset);
    return out;
}
736
/* Stride 2: evaluate at stride 1, then gather every second result into lanes 0-7 of val[0]. */
template <>
inline float16x8x2_t convolve_3x3_dilation<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                              const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                              const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    float16x8x2_t res = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    const float16_t r1 = vgetq_lane_f16(res.val[0], 2);
    const float16_t r2 = vgetq_lane_f16(res.val[0], 4);
    const float16_t r3 = vgetq_lane_f16(res.val[0], 6);
    const float16_t r4 = vgetq_lane_f16(res.val[1], 0);
    const float16_t r5 = vgetq_lane_f16(res.val[1], 2);
    const float16_t r6 = vgetq_lane_f16(res.val[1], 4);
    const float16_t r7 = vgetq_lane_f16(res.val[1], 6);
    res.val[0]         = vsetq_lane_f16(r1, res.val[0], 1);
    res.val[0]         = vsetq_lane_f16(r2, res.val[0], 2);
    res.val[0]         = vsetq_lane_f16(r3, res.val[0], 3);
    res.val[0]         = vsetq_lane_f16(r4, res.val[0], 4);
    res.val[0]         = vsetq_lane_f16(r5, res.val[0], 5);
    res.val[0]         = vsetq_lane_f16(r6, res.val[0], 6);
    res.val[0]         = vsetq_lane_f16(r7, res.val[0], 7);
    return res;
}
753
/* Stride 3: evaluate at stride 1, then gather every third result into lanes 0-3 of val[0]. */
template <>
inline float16x8x2_t convolve_3x3_dilation<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                              const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                              const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    float16x8x2_t res = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset);
    const float16_t r1 = vgetq_lane_f16(res.val[0], 3);
    const float16_t r2 = vgetq_lane_f16(res.val[0], 6);
    const float16_t r3 = vgetq_lane_f16(res.val[1], 1);
    res.val[0]         = vsetq_lane_f16(r1, res.val[0], 1);
    res.val[0]         = vsetq_lane_f16(r2, res.val[0], 2);
    res.val[0]         = vsetq_lane_f16(r3, res.val[0], 3);
    return res;
}
766
/** Perform a convolve3x3 on float16.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset (Optional) Input quantization offset; ignored by the float16
 *                         specializations, kept for signature symmetry with the quantized path.
 *
 * @return Convolution result; for stride > 1 the valid outputs are compacted into val[0].
 */
template <unsigned int stridex>
float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                           int input_offset = 0);
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100781
/** Stride-1 specialization: produces 16 contiguous float16 outputs per call. */
template <>
inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                     const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int input_offset)
{
    // input_offset only matters for quantized types; unused in the float16 path.
    ARM_COMPUTE_UNUSED(input_offset);
    // Load 24 contiguous fp16 values per row: computing 16 outputs needs taps up to index 17.
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + 8),
            vld1q_f16(in_top + 16)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + 8),
            vld1q_f16(in_mid + 16)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + 8),
            vld1q_f16(in_low + 16)
        }
    };
    // Seed both 8-lane halves with the top-row, tap-0 products.
    float16x8x2_t out =
    {
        {
            vmulq_f16(vtop.val[0], m0.val[0]),
            vmulq_f16(vtop.val[1], m0.val[0])
        }
    };
    // First 8 outputs: vextq shifts the row window by 1 and 2 lanes to align taps 1 and 2,
    // then the mid and low rows are accumulated the same way.
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
    // Second 8 outputs (lanes 8..15): same pattern on the next 8-lane windows.
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
    return out;
}
837
/** Stride-2 specialization: the 8 valid outputs are compacted into val[0].
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset Input quantization offset (ignored by the float16 stride-1 kernel).
 *
 * @return Convolution result with the stride-2 outputs in val[0].
 */
template <>
inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                     const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int input_offset)
{
    // Forward input_offset for consistency with the dilation overloads; the stride-1
    // float16 kernel ignores it, so behavior is unchanged.
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
    // Gather every second lane of the 16 stride-1 results into the first vector.
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    return out;
}
854
/** Stride-3 specialization: the 4 valid outputs are compacted into val[0].
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset Input quantization offset (ignored by the float16 stride-1 kernel).
 *
 * @return Convolution result with the stride-3 outputs in val[0].
 */
template <>
inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                     const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int input_offset)
{
    // Forward input_offset for consistency with the dilation overloads; the stride-1
    // float16 kernel ignores it, so behavior is unchanged.
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
    // Gather every third lane of the 16 stride-1 results into the first vector.
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    return out;
}
867
/** Stores a float16x8x2_t array into a memory location.
 *
 * The number of elements actually written depends on the stride: the
 * specializations below store 16 (stride 1), 8 (stride 2) or 4 (stride 3) values.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);
876
/** Stride-1 store: writes all 16 float16 results to @p buffer. */
template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t lower_half = values.val[0];
    const float16x8_t upper_half = values.val[1];
    vst1q_f16(buffer + 0, lower_half);
    vst1q_f16(buffer + 8, upper_half);
}
883
/** Stride-2 store: only the first 8 compacted results in val[0] are valid. */
template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t compacted = values.val[0];
    vst1q_f16(buffer, compacted);
}
889
/** Stride-3 store: only the first 4 compacted results are valid, so store the low half of val[0]. */
template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x4_t compacted = vget_low_f16(values.val[0]);
    vst1_f16(buffer, compacted);
}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000895#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100896
/** Get the number of elements processed on 3x3 convolution.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
 *
 * @return The number of input elements processed per iteration (written elements scaled by the stride).
 */
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    // Stride 1: one input element consumed per output element.
    return num_elems_written_per_iteration;
}

template <>
inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    // Plain multiplication (instead of << 1) for consistency with the stride-3 specialization.
    return num_elems_written_per_iteration * 2;
}

template <>
inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}
Anthony Barbier15686212017-12-12 17:17:50 +0000923inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
924{
925 switch(stridex)
926 {
927 case 1:
928 return get_input_num_elems_processed<1>(num_elems_written_per_iteration);
929 case 2:
930 return get_input_num_elems_processed<2>(num_elems_written_per_iteration);
931 case 3:
932 return get_input_num_elems_processed<3>(num_elems_written_per_iteration);
933 default:
934 ARM_COMPUTE_ERROR("stridex not supported");
935 return 0;
936 }
937}
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100938}
939} // namespace arm_compute
Anthony Barbier15686212017-12-12 17:17:50 +0000940#endif /* __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__ */