blob: c358558610bdd3c304b1714830ab338d0f6fd198 [file] [log] [blame]
Michalis Spyrou7362f0d2017-10-18 17:58:22 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
26#define __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
27
28#include "arm_compute/core/AccessWindowStatic.h"
29#include "arm_compute/core/NEON/NEFixedPoint.h"
30
31#include <arm_neon.h>
32
33namespace arm_compute
34{
35namespace detail
36{
37/** Loads a 3x3 matrix as a row (float).
38 *
39 * @param[in] ptr Pointer to a float 3x3 matrix.
40 *
41 * @return The loaded matrix.
42 */
43inline float32x4x3_t load_matrix_row(const float *ptr)
44{
45 const float32x4x3_t r =
46 {
47 {
48 vld1q_dup_f32(ptr),
49 vld1q_dup_f32(1 + ptr),
50 vld1q_dup_f32(2 + ptr)
51 }
52 };
53 return r;
54}
55
56/** Loads a 3x3 matrix as a row (qint8_t).
57 *
58 * @param[in] ptr Pointer to a qint8 3x3 matrix.
59 *
60 * @return The loaded matrix.
61 */
62inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
63{
64 /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
65 r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
66 const qint8x8x3_t r =
67 {
68 {
69 vld1_dup_qs8(ptr),
70 vld1_dup_qs8(1 + ptr),
71 vld1_dup_qs8(2 + ptr)
72 }
73 };
74 return r;
75}
76
77/** Perform a convolve3x3 on float32.
78 *
79 * @param[in] in_top Pointer to the first row of the input.
80 * @param[in] in_mid Pointer to the second row of the input.
81 * @param[in] in_low Pointer to the third row of the input.
82 * @param[in] m0 First row of the filter.
83 * @param[in] m1 Second row of the filter.
84 * @param[in] m2 Third row of the filter.
85 * @param[in] fixed_point_position (Optional) Fixed point position.
86 *
87 */
88template <unsigned int stridex>
89float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);
90
91template <>
92inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
93{
94 ARM_COMPUTE_UNUSED(fixed_point_position);
95
96 const float32x4x3_t vtop =
97 {
98 {
99 vld1q_f32(in_top),
100 vld1q_f32(in_top + 4),
101 vld1q_f32(in_top + 8)
102 }
103 };
104 const float32x4x3_t vmid =
105 {
106 {
107 vld1q_f32(in_mid),
108 vld1q_f32(in_mid + 4),
109 vld1q_f32(in_mid + 8)
110 }
111 };
112 const float32x4x3_t vlow =
113 {
114 {
115 vld1q_f32(in_low),
116 vld1q_f32(in_low + 4),
117 vld1q_f32(in_low + 8)
118 }
119 };
120 float32x4x2_t out =
121 {
122 {
123 vmulq_f32(vtop.val[0], m0.val[0]),
124 vmulq_f32(vtop.val[1], m0.val[0])
125 }
126 };
127 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
128 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
129
130 out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
131 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
132 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
133
134 out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
135 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
136 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
137
138 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
139 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
140
141 out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
142 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
143 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
144
145 out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
146 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
147 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
148 return out;
149}
150
151template <>
152inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
153{
154 float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
155 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
156 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
157 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
158 return out;
159}
160
161template <>
162inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
163{
164 float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
165 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
166 return out;
167}
168
169/** Perform a convolve3x3 on qint16.
170 *
171 * @param[in] in_top Pointer to the first row of the input.
172 * @param[in] in_mid Pointer to the second row of the input.
173 * @param[in] in_low Pointer to the third row of the input.
174 * @param[in] m0 First row of the filter.
175 * @param[in] m1 Second row of the filter.
176 * @param[in] m2 Third row of the filter.
177 * @param[in] fixed_point_position (Optional) Fixed point position.
178 *
179 */
180template <unsigned int stridex>
181qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position);
182
183template <>
184inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
185{
186 ARM_COMPUTE_UNUSED(fixed_point_position);
187
188 const qint8x8x3_t vtop =
189 {
190 {
191 vld1_qs8(in_top),
192 vld1_qs8(in_top + 8),
193 vld1_qs8(in_top + 16)
194 }
195 };
196 const qint8x8x3_t vmid =
197 {
198 {
199 vld1_qs8(in_mid),
200 vld1_qs8(in_mid + 8),
201 vld1_qs8(in_mid + 16)
202 }
203 };
204 const qint8x8x3_t vlow =
205 {
206 {
207 vld1_qs8(in_low),
208 vld1_qs8(in_low + 8),
209 vld1_qs8(in_low + 16)
210 }
211 };
212 qint16x8x2_t out =
213 {
214 {
215 vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
216 vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
217 }
218 };
219 out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
220 out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
221 out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
222 out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
223 out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
224 out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
225 out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
226 out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
227 out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
228 out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
229 out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
230 out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
231 out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
232 out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
233 out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
234 out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
235 return out;
236}
237
238template <>
239inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
240{
241 qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
242 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
243 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
244 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
245 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
246 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
247 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
248 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
249 return out;
250}
251
252template <>
253inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
254{
255 qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
256 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
257 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
258 out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
259 return out;
260}
261
262/** Stores a float32x4x2_t array into a memory location.
263 *
264 * @param[in] buffer Pointer to the memory location where the values will be stored.
265 * @param[in] values Values that will be stored.
266 *
267 */
268template <unsigned int stridex>
269void store_results(float *buffer, const float32x4x2_t &values);
270
271template <>
272inline void store_results<1>(float *buffer, const float32x4x2_t &values)
273{
274 vst1q_f32(buffer, values.val[0]);
275 vst1q_f32(buffer + 4, values.val[1]);
276}
277
278template <>
279inline void store_results<2>(float *buffer, const float32x4x2_t &values)
280{
281 vst1q_f32(buffer, values.val[0]);
282}
283
284template <>
285inline void store_results<3>(float *buffer, const float32x4x2_t &values)
286{
287 vst1_f32(buffer, vget_low_f32(values.val[0]));
288}
289
290/** Stores a qint16_t array into a memory location.
291 *
292 * @param[in] buffer Pointer to the memory location where the values will be stored.
293 * @param[in] values Values that will be stored.
294 *
295 */
296template <unsigned int stridex>
297void store_results(qint16_t *buffer, const qint16x8x2_t &values);
298
299template <>
300inline void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
301{
302 vst1q_qs16(buffer, values.val[0]);
303 vst1q_qs16(buffer + 8, values.val[1]);
304}
305
306template <>
307inline void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
308{
309 vst1q_qs16(buffer, values.val[0]);
310}
311
312template <>
313inline void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
314{
315 vst1_qs16(buffer, vget_low_s16(values.val[0]));
316}
317
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000318#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100319/** Loads a 3x3 matrix as a row (float16_t).
320 *
321 * @param[in] ptr Pointer to a float 3x3 matrix.
322 *
323 * @return The loaded matrix.
324 */
325inline float16x8x3_t load_matrix_row(const float16_t *ptr)
326{
327 /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
328 r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
329 const float16x8x3_t r =
330 {
331 {
332 vld1q_dup_f16(ptr),
333 vld1q_dup_f16(1 + ptr),
334 vld1q_dup_f16(2 + ptr)
335 }
336 };
337 return r;
338}
339
340/** Perform a convolve3x3 on float16.
341 *
342 * @param[in] in_top Pointer to the first row of the input.
343 * @param[in] in_mid Pointer to the second row of the input.
344 * @param[in] in_low Pointer to the third row of the input.
345 * @param[in] m0 First row of the filter.
346 * @param[in] m1 Second row of the filter.
347 * @param[in] m2 Third row of the filter.
348 * @param[in] fixed_point_position (Optional) Fixed point position.
349 *
350 */
351template <unsigned int stridex>
352float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
353 int fixed_point_position);
354
355template <>
356inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
357 int fixed_point_position)
358{
359 ARM_COMPUTE_UNUSED(fixed_point_position);
360
361 const float16x8x3_t vtop =
362 {
363 {
364 vld1q_f16(in_top),
365 vld1q_f16(in_top + 8),
366 vld1q_f16(in_top + 16)
367 }
368 };
369 const float16x8x3_t vmid =
370 {
371 {
372 vld1q_f16(in_mid),
373 vld1q_f16(in_mid + 8),
374 vld1q_f16(in_mid + 16)
375 }
376 };
377 const float16x8x3_t vlow =
378 {
379 {
380 vld1q_f16(in_low),
381 vld1q_f16(in_low + 8),
382 vld1q_f16(in_low + 16)
383 }
384 };
385 float16x8x2_t out =
386 {
387 {
388 vmulq_f16(vtop.val[0], m0.val[0]),
389 vmulq_f16(vtop.val[1], m0.val[0])
390 }
391 };
392 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
393 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
394 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
395 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
396 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
397 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
398 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
399 out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
400 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
401 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
402 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
403 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
404 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
405 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
406 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
407 out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
408 return out;
409}
410
411template <>
412inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
413 int fixed_point_position)
414{
415 float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
416 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
417 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 2);
418 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 3);
419 return out;
420}
421
422template <>
423inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
424 int fixed_point_position)
425{
426 float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
427 out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
428 return out;
429}
430
431/** Stores a float16x8x2_t array into a memory location.
432 *
433 * @param[in] buffer Pointer to the memory location where the values will be stored.
434 * @param[in] values Values that will be stored.
435 *
436 */
437template <unsigned int stridex>
438void store_results(float16_t *buffer, const float16x8x2_t &values);
439
440template <>
441inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
442{
443 vst1q_f16(buffer, values.val[0]);
444 vst1q_f16(buffer + 8, values.val[1]);
445}
446
447template <>
448inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
449{
450 vst1q_f16(buffer, values.val[0]);
451}
452
453template <>
454inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
455{
456 vst1_f16(buffer, vget_low_f16(values.val[0]));
457}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000458#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100459
/** Number of input elements processed along x for a 3x3 convolution step.
 *
 * The result is @p num_elems_written_per_iteration multiplied by the horizontal stride.
 * Only stridex 1, 2 and 3 are specialised.
 *
 * @param[in] num_elems_written_per_iteration Number of output elements written per iteration of the 3x3 convolution.
 *
 * @return The number of input elements processed.
 */
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    // Dense convolution: one input element per written output.
    return static_cast<int>(num_elems_written_per_iteration);
}

template <>
inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return static_cast<int>(num_elems_written_per_iteration * 2u);
}

template <>
inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return static_cast<int>(num_elems_written_per_iteration * 3u);
}
Anthony Barbier15686212017-12-12 17:17:50 +0000486inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
487{
488 switch(stridex)
489 {
490 case 1:
491 return get_input_num_elems_processed<1>(num_elems_written_per_iteration);
492 case 2:
493 return get_input_num_elems_processed<2>(num_elems_written_per_iteration);
494 case 3:
495 return get_input_num_elems_processed<3>(num_elems_written_per_iteration);
496 default:
497 ARM_COMPUTE_ERROR("stridex not supported");
498 return 0;
499 }
500}
Michalis Spyrou7362f0d2017-10-18 17:58:22 +0100501}
502} // namespace arm_compute
Anthony Barbier15686212017-12-12 17:17:50 +0000503#endif /* __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__ */