blob: d56fd44700e2a2ca33633fc936c567a7bf5aac47 [file] [log] [blame]
Michalis Spyrou7362f0d2017-10-18 17:58:22 +01001/*
Georgios Pinitasf72f9362018-01-12 16:29:45 +00002 * Copyright (c) 2017-2018 ARM Limited.
Michalis Spyrou7362f0d2017-10-18 17:58:22 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
26#define __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
27
28#include "arm_compute/core/AccessWindowStatic.h"
29#include "arm_compute/core/NEON/NEFixedPoint.h"
30
31#include <arm_neon.h>
32
33namespace arm_compute
34{
35namespace detail
36{
/** Loads a 3x3 matrix as a row (float).
 *
 * @param[in] ptr            Pointer to a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    // The offset only matters for quantized data; it is ignored for float.
    ARM_COMPUTE_UNUSED(weights_offset);

    /* ptr points at one row of a 3x3 matrix; each returned vector broadcasts
       one of the three row elements across all of its lanes. */
    const float32x4x3_t row_vectors =
    {
        {
            vld1q_dup_f32(ptr + 0),
            vld1q_dup_f32(ptr + 1),
            vld1q_dup_f32(ptr + 2)
        }
    };
    return row_vectors;
}
57
/** Loads a 3x3 matrix as a row (uint8_t).
 *
 * @param[in] ptr            Pointer to a uint8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline int32x4x3_t load_matrix_row(const uint8_t *ptr, int weights_offset = 0)
{
    const int32x4_t v_offset = vdupq_n_s32(weights_offset);

    /* ptr points at one row of a 3x3 matrix. Each returned vector holds one
       row element, widened to int32, broadcast to all lanes, and with the
       quantization offset added to every lane. */
    int32x4x3_t row_vectors =
    {
        {
            vaddq_s32(v_offset, vdupq_n_s32(static_cast<int32_t>(ptr[0]))),
            vaddq_s32(v_offset, vdupq_n_s32(static_cast<int32_t>(ptr[1]))),
            vaddq_s32(v_offset, vdupq_n_s32(static_cast<int32_t>(ptr[2])))
        }
    };
    return row_vectors;
}
81
/** Perform a convolve3x3 on float32.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset (Optional) Input quantization offset. Unused for float.
 *
 * @return Two vectors with the convolved values.
 */
template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low,
                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                           int input_offset = 0);
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010097
template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low,
                                     const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                     int input_offset)
{
    // The offset only matters for quantized data; it is ignored for float.
    ARM_COMPUTE_UNUSED(input_offset);

    // Load 12 consecutive input columns per row: columns [0..3] in val[0],
    // [4..7] in val[1], [8..11] in val[2].
    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    // out.val[0] accumulates outputs 0..3 and out.val[1] outputs 4..7; both
    // are seeded with the top-left filter tap applied to the aligned columns.
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    // vextq_f32(a, b, n) yields the input window shifted right by n columns,
    // which pairs each filter tap with the correct input column.
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}
159
/** Stride-2 float32 convolve3x3: computes the stride-1 result and keeps every
 * second output lane.
 */
template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low,
                                     const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                     int input_offset)
{
    // NOTE: input_offset is forwarded to the stride-1 kernel below, so marking
    // it ARM_COMPUTE_UNUSED (as before) was contradictory and has been removed.
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);

    // Pack the even-indexed outputs (0,2,4,6) into the lanes of out.val[0].
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}
173
/** Stride-3 float32 convolve3x3: computes the stride-1 result and keeps every
 * third output lane.
 */
template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low,
                                     const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                     int input_offset)
{
    // NOTE: input_offset is forwarded to the stride-1 kernel below, so marking
    // it ARM_COMPUTE_UNUSED (as before) was contradictory and has been removed.
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);

    // Move output 3 next to output 0; lanes 2 and 3 are not consumed by callers.
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}
185
/** Perform a convolve3x3 on uint8_t
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 * @return Two vectors with the convolved values.
 */
template <unsigned int stridex>
int32x4x2_t convolve_3x3(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                         const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                         int input_offset = 0); // default added: doc declares the offset optional, matching the float overload
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000201
template <>
inline int32x4x2_t convolve_3x3<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                   int input_offset)
{
    const int32x4_t v_input_offset = vdupq_n_s32(input_offset);

    // Load 16 consecutive uint8 columns per row (only the first 12 are used).
    const uint8x8x2_t vtop =
    {
        {
            vld1_u8(in_top),
            vld1_u8(in_top + 8)
        }
    };
    const uint8x8x2_t vmid =
    {
        {
            vld1_u8(in_mid),
            vld1_u8(in_mid + 8)
        }
    };
    const uint8x8x2_t vlow =
    {
        {
            vld1_u8(in_low),
            vld1_u8(in_low + 8)
        }
    };

    // Widen each row to int32 (u8 -> u16 -> s16 -> s32) and add the input
    // quantization offset: val[0] holds columns [0..3], val[1] [4..7],
    // val[2] [8..11].
    const int32x4x3_t vtop_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vtop.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vmid.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vlow.val[0])))),
            vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))),
        }
    };

    // out.val[0] accumulates outputs 0..3 and out.val[1] outputs 4..7.
    int32x4x2_t out
    {
        {
            vdupq_n_s32(0),
            vdupq_n_s32(0),
        }
    };

    // 0
    // vextq_s32(a, b, n) yields the input window shifted right by n columns,
    // which pairs each filter tap with the correct input column.
    out.val[0] = vmlaq_s32(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_s32(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_s32(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 2), m2.val[2]);

    // 1
    out.val[1] = vmlaq_s32(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_s32(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_s32(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 2), m2.val[2]);

    return out;
}
291
/** Stride-2 uint8_t convolve3x3: computes the stride-1 result and keeps every
 * second output lane.
 */
template <>
inline int32x4x2_t convolve_3x3<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                   const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                   int input_offset)
{
    int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);

    // Gather the even-indexed outputs (0,2,4,6) into the lanes of out.val[0].
    const int32_t even2 = vgetq_lane_s32(out.val[0], 2);
    const int32_t even4 = vgetq_lane_s32(out.val[1], 0);
    const int32_t even6 = vgetq_lane_s32(out.val[1], 2);
    out.val[0]          = vsetq_lane_s32(even2, out.val[0], 1);
    out.val[0]          = vsetq_lane_s32(even4, out.val[0], 2);
    out.val[0]          = vsetq_lane_s32(even6, out.val[0], 3);
    return out;
}
303
/** Stride-3 uint8_t convolve3x3: computes the stride-1 result and keeps every
 * third output lane.
 */
template <>
inline int32x4x2_t convolve_3x3<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
                                   const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                   int input_offset)
{
    int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);

    // Move output 3 next to output 0; the remaining lanes are not consumed.
    const int32_t third = vgetq_lane_s32(out.val[0], 3);
    out.val[0]          = vsetq_lane_s32(third, out.val[0], 1);
    return out;
}
313
/** Stores a float32x4x2_t array into a memory location.
 *
 * The number of elements actually written depends on the stride: 8 for
 * stride 1, 4 for stride 2 and 2 for stride 3.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);
322
/** Stride 1: all 8 computed outputs are valid; store both vectors. */
template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer + 0, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}
329
/** Stride 2: only the first vector holds valid (compacted) outputs. */
template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}
335
/** Stride 3: only the first two lanes hold valid outputs; store half a vector. */
template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    const float32x2_t valid_pair = vget_low_f32(values.val[0]);
    vst1_f32(buffer, valid_pair);
}
341
/** Stores an int32_t array into a memory location.
 *
 * The number of elements actually written depends on the stride: 8 for
 * stride 1, 4 for stride 2 and 2 for stride 3.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);
350
/** Stride 1: all 8 computed outputs are valid; store both vectors. */
template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer + 0, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}
357
/** Stride 2: only the first vector holds valid (compacted) outputs. */
template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
}
363
/** Stride 3: only the first two lanes hold valid outputs; store half a vector. */
template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    const int32x2_t valid_pair = vget_low_s32(values.val[0]);
    vst1_s32(buffer, valid_pair);
}
369
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000370#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads a 3x3 matrix as a row (float16_t).
 *
 * @param[in] ptr Pointer to a float 3x3 matrix.
 *
 * @return The loaded matrix.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr)
{
    /* ptr points at one row of a 3x3 matrix; each returned vector broadcasts
       one of the three row elements across all of its lanes. */
    const float16x8x3_t row_vectors =
    {
        {
            vld1q_dup_f16(ptr + 0),
            vld1q_dup_f16(ptr + 1),
            vld1q_dup_f16(ptr + 2)
        }
    };
    return row_vectors;
}
391
/** Perform a convolve3x3 on float16.
 *
 * @param[in] in_top Pointer to the first row of the input.
 * @param[in] in_mid Pointer to the second row of the input.
 * @param[in] in_low Pointer to the third row of the input.
 * @param[in] m0     First row of the filter.
 * @param[in] m1     Second row of the filter.
 * @param[in] m2     Third row of the filter.
 *
 * @return Two vectors with the convolved values.
 */
template <unsigned int stridex>
float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2);
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100404
template <>
inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2)
{
    // Load 24 consecutive input columns per row: columns [0..7] in val[0],
    // [8..15] in val[1], [16..23] in val[2].
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + 8),
            vld1q_f16(in_top + 16)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + 8),
            vld1q_f16(in_mid + 16)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + 8),
            vld1q_f16(in_low + 16)
        }
    };
    // out.val[0] accumulates outputs 0..7 and out.val[1] outputs 8..15; both
    // are seeded with the top-left filter tap applied to the aligned columns.
    float16x8x2_t out =
    {
        {
            vmulq_f16(vtop.val[0], m0.val[0]),
            vmulq_f16(vtop.val[1], m0.val[0])
        }
    };
    // vextq_f16(a, b, n) yields the input window shifted right by n columns,
    // which pairs each filter tap with the correct input column.
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
    return out;
}
457
/** Stride-2 float16 convolve3x3: computes the stride-1 result and keeps every
 * second output lane.
 */
template <>
inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);

    // Gather the even-indexed outputs into the first lanes of out.val[0].
    const float16_t even2 = vgetq_lane_f16(out.val[0], 2);
    const float16_t even4 = vgetq_lane_f16(out.val[1], 0);
    const float16_t even6 = vgetq_lane_f16(out.val[1], 2);
    out.val[0]            = vsetq_lane_f16(even2, out.val[0], 1);
    out.val[0]            = vsetq_lane_f16(even4, out.val[0], 2);
    out.val[0]            = vsetq_lane_f16(even6, out.val[0], 3);
    return out;
}
467
/** Stride-3 float16 convolve3x3: computes the stride-1 result and keeps every
 * third output lane.
 */
template <>
inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);

    // Move output 3 next to output 0; the remaining lanes are not consumed.
    const float16_t third = vgetq_lane_f16(out.val[0], 3);
    out.val[0]            = vsetq_lane_f16(third, out.val[0], 1);
    return out;
}
475
/** Stores a float16x8x2_t array into a memory location.
 *
 * The number of elements actually written depends on the stride: 16 for
 * stride 1, 8 for stride 2 and 4 for stride 3.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);
484
/** Stride 1: all 16 computed outputs are valid; store both vectors. */
template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer + 0, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}
491
/** Stride 2: only the first vector holds valid (compacted) outputs. */
template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}
497
/** Stride 3: only the low half of the first vector holds valid outputs. */
template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x4_t valid_half = vget_low_f16(values.val[0]);
    vst1_f16(buffer, valid_half);
}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000503#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100504
/** Get the number of elements processed on 3x3 convolution.
 *
 * For stride N the kernel consumes N input elements per output element.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
 *
 * @return The number of elements processed.
 */
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);
513
514template <>
515inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
516{
517 return num_elems_written_per_iteration;
518}
519
520template <>
521inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
522{
523 return num_elems_written_per_iteration << 1;
524}
525
526template <>
527inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
528{
529 return num_elems_written_per_iteration * 3;
530}
Anthony Barbier15686212017-12-12 17:17:50 +0000531inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
532{
533 switch(stridex)
534 {
535 case 1:
536 return get_input_num_elems_processed<1>(num_elems_written_per_iteration);
537 case 2:
538 return get_input_num_elems_processed<2>(num_elems_written_per_iteration);
539 case 3:
540 return get_input_num_elems_processed<3>(num_elems_written_per_iteration);
541 default:
542 ARM_COMPUTE_ERROR("stridex not supported");
543 return 0;
544 }
545}
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100546}
547} // namespace arm_compute
Anthony Barbier15686212017-12-12 17:17:50 +0000548#endif /* __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__ */