blob: 88eaa5f07cae593c13d34a7982b115682a853c7f [file] [log] [blame]
/*
 * Copyright (c) 2019-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#ifdef __aarch64__

/*
 * MergeResults<4, 4, false> for int32_t: merge packed 4x4 result blocks from
 * `in` into rows [y0, ymax) and columns [x0, xmax) of `out`.
 *
 * Input layout: `in` is consumed as consecutive 16-element blocks.  Within a
 * block, row r (0..3) occupies elements [4*r, 4*r + 3], one element per
 * column.  Blocks are consumed in loop order: for each group of 4 rows
 * starting at y0, one block per group of 4 columns from x0 up to xmax.
 * The input pointer advances by a full 16 elements per block even when fewer
 * than 4 rows or columns are valid, so partial blocks are padded in `in`.
 *
 * Parameters:
 *   out      - output matrix; element (y, x) lives at out[y * ldout + x].
 *   in       - packed input blocks as described above.
 *   ldout    - row stride of `out`, in elements.
 *   y0, ymax - half-open row range of `out` to write.
 *   x0, xmax - half-open column range of `out` to write.
 *   bias     - per-column bias indexed by absolute column (bias[x]); may be
 *              null, in which case a zero bias is used.  Only read when
 *              `append` is false.  Every read stays inside [x0, xmax), which
 *              includes the 16-byte vector loads of bias[i..i+3].
 *   (Activation) - unnamed and ignored by this specialisation.
 *   append   - true:  out[y][x] += in value (bias not applied);
 *              false: out[y][x]  = bias[x] + in value.
 *
 * Rows beyond ymax in the last row group are never written.  Full 4-column
 * blocks use AArch64 NEON inline assembly.  The trailing partial column block
 * (fewer than 4 columns) uses the scalar loops.
 */
template<>
void MergeResults<4, 4, false>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t *bias, Activation , bool append)
{
    const int32_t *inptr = in;
    int32_t nullbias[4];

    // A zero bias is only needed on the non-append path when no bias array was
    // supplied.  In every other case nullbias is never read.
    if (!append && !bias)
    {
        memset(nullbias, 0, (4 * sizeof(int32_t)));
    }

    for (int y=y0; y<ymax; y+=4)
    {
        int32_t *outptr0 = out + (y * ldout) + x0;
        int32_t *outptr1 = outptr0 + ldout;
        int32_t *outptr2 = outptr1 + ldout;
        int32_t *outptr3 = outptr2 + ldout;

        // Valid rows in this group.  1..3 only for the last group; anything
        // larger is handled by the 4-row path (default case).
        const int height = ymax - y;

        for (int i=x0; i<xmax; i+=4)
        {
            if (append)
            {
                // Accumulate: out += in (bias is not applied when appending).
                switch(height)
                {
                    case 1:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block: column i+3 is out of range,
                            // so at most 3 columns (xi = 0..2) remain.
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 += inptr[xi];
                                    outptr0++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Adds the 4-column input row to the output row,
                            // advances inptr by one block (0x40 bytes) and
                            // outptr0 by 4 columns (0x10 bytes), with prefetch
                            // hints for upcoming input and output data.
                            __asm __volatile (
                                "ldr q2, [%[outptr0]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q10, [%[inptr]]\n"
                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "add v10.4s, v10.4s, v2.4s\n"
                                "str q10, [%[outptr0]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            :
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;

                    case 2:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 += inptr[xi];
                                    outptr0++;
                                    *outptr1 += inptr[xi + 4];
                                    outptr1++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Rows 0-1: input rows at block offsets 0x00/0x10.
                            __asm __volatile (
                                "ldr q2, [%[outptr0]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q10, [%[inptr]]\n"
                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
                                "ldr q3, [%[outptr1]]\n"
                                "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
                                "add v10.4s, v10.4s, v2.4s\n"
                                "ldr q11, [%[inptr], #0x10]\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "add v11.4s, v11.4s, v3.4s\n"
                                "str q10, [%[outptr0]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                                "str q11, [%[outptr1]]\n"
                                "add %[outptr1], %[outptr1], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            :
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;

                    case 3:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 += inptr[xi];
                                    outptr0++;
                                    *outptr1 += inptr[xi + 4];
                                    outptr1++;
                                    *outptr2 += inptr[xi + 8];
                                    outptr2++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Rows 0-2: input rows at block offsets 0x00/0x10/0x20.
                            __asm __volatile (
                                "ldr q2, [%[outptr0]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q10, [%[inptr]]\n"
                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
                                "ldr q3, [%[outptr1]]\n"
                                "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
                                "add v10.4s, v10.4s, v2.4s\n"
                                "ldr q11, [%[inptr], #0x10]\n"
                                "ldr q4, [%[outptr2]]\n"
                                "prfm PLDL1KEEP, [%[outptr2], #0x20]\n"
                                "ldr q12, [%[inptr], #0x20]\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "add v11.4s, v11.4s, v3.4s\n"
                                "str q10, [%[outptr0]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                                "add v12.4s, v12.4s, v4.4s\n"
                                "str q11, [%[outptr1]]\n"
                                "add %[outptr1], %[outptr1], #0x10\n"
                                "str q12, [%[outptr2]]\n"
                                "add %[outptr2], %[outptr2], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            :
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;

                    default:
                    case 4:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 += inptr[xi];
                                    outptr0++;
                                    *outptr1 += inptr[xi + 4];
                                    outptr1++;
                                    *outptr2 += inptr[xi + 8];
                                    outptr2++;
                                    *outptr3 += inptr[xi + 12];
                                    outptr3++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // All 4 rows: input rows at block offsets 0x00..0x30.
                            __asm __volatile (
                                "ldr q2, [%[outptr0]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q10, [%[inptr]]\n"
                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
                                "ldr q3, [%[outptr1]]\n"
                                "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
                                "add v10.4s, v10.4s, v2.4s\n"
                                "ldr q11, [%[inptr], #0x10]\n"
                                "ldr q4, [%[outptr2]]\n"
                                "prfm PLDL1KEEP, [%[outptr2], #0x20]\n"
                                "ldr q12, [%[inptr], #0x20]\n"
                                "prfm PLDL1KEEP, [%[outptr3], #0x20]\n"
                                "add v11.4s, v11.4s, v3.4s\n"
                                "str q10, [%[outptr0]]\n"
                                "ldr q5, [%[outptr3]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                                "add v12.4s, v12.4s, v4.4s\n"
                                "str q11, [%[outptr1]]\n"
                                "ldr q13, [%[inptr], #0x30]\n"
                                "add %[outptr1], %[outptr1], #0x10\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "str q12, [%[outptr2]]\n"
                                "add %[outptr2], %[outptr2], #0x10\n"
                                "add v13.4s, v13.4s, v5.4s\n"
                                "str q13, [%[outptr3]]\n"
                                "add %[outptr3], %[outptr3], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            :
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;
                }
            }
            else
            {
                // Overwrite: out = bias + in.  The bias is indexed by absolute
                // column, so it is offset by i.  With no bias, use the zeroed
                // 4-element nullbias, which is enough for one column block.
                const int32_t *biasptr = bias ? bias + i : nullbias;

                switch(height)
                {
                    case 1:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 = biasptr[xi] + inptr[xi];
                                    outptr0++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Loads 4 bias values once (q2) and adds them to the
                            // input row.  PSTL1KEEP hints output stores.
                            __asm __volatile (
                                "ldr q2, [%[biasptr]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q11, [%[inptr]]\n"
                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "add v11.4s, v11.4s, v2.4s\n"
                                "str q11, [%[outptr0]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            : [biasptr] "r" (biasptr)
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;

                    case 2:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 = biasptr[xi] + inptr[xi];
                                    outptr0++;
                                    *outptr1 = biasptr[xi] + inptr[xi + 4];
                                    outptr1++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Same bias vector (q2) added to rows 0-1.
                            __asm __volatile (
                                "ldr q2, [%[biasptr]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q11, [%[inptr]]\n"
                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
                                "ldr q12, [%[inptr], #0x10]\n"
                                "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
                                "add v11.4s, v11.4s, v2.4s\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "add v12.4s, v12.4s, v2.4s\n"
                                "str q11, [%[outptr0]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                                "str q12, [%[outptr1]]\n"
                                "add %[outptr1], %[outptr1], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            : [biasptr] "r" (biasptr)
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;

                    case 3:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 = biasptr[xi] + inptr[xi];
                                    outptr0++;
                                    *outptr1 = biasptr[xi] + inptr[xi + 4];
                                    outptr1++;
                                    *outptr2 = biasptr[xi] + inptr[xi + 8];
                                    outptr2++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Same bias vector (q2) added to rows 0-2.
                            __asm __volatile (
                                "ldr q2, [%[biasptr]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q11, [%[inptr]]\n"
                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
                                "ldr q12, [%[inptr], #0x10]\n"
                                "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
                                "add v11.4s, v11.4s, v2.4s\n"
                                "ldr q13, [%[inptr], #0x20]\n"
                                "prfm PSTL1KEEP, [%[outptr2], #0x20]\n"
                                "add v12.4s, v12.4s, v2.4s\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "add v13.4s, v13.4s, v2.4s\n"
                                "str q11, [%[outptr0]]\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                                "str q12, [%[outptr1]]\n"
                                "add %[outptr1], %[outptr1], #0x10\n"
                                "str q13, [%[outptr2]]\n"
                                "add %[outptr2], %[outptr2], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            : [biasptr] "r" (biasptr)
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;

                    default:
                    case 4:
                    {
                        if ((i+3) >= xmax)
                        {
                            // Partial column block (at most 3 valid columns).
                            for (int xi=0; xi<3; xi++)
                            {
                                if ((i+xi) < xmax)
                                {
                                    *outptr0 = biasptr[xi] + inptr[xi];
                                    outptr0++;
                                    *outptr1 = biasptr[xi] + inptr[xi + 4];
                                    outptr1++;
                                    *outptr2 = biasptr[xi] + inptr[xi + 8];
                                    outptr2++;
                                    *outptr3 = biasptr[xi] + inptr[xi + 12];
                                    outptr3++;
                                }
                            }
                            inptr += 16;
                        } else {
                            /* Optimized routine to copy an entire block */
                            // Same bias vector (q2) added to all 4 rows.
                            __asm __volatile (
                                "ldr q2, [%[biasptr]]\n"
                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
                                "ldr q11, [%[inptr]]\n"
                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
                                "ldr q12, [%[inptr], #0x10]\n"
                                "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
                                "add v11.4s, v11.4s, v2.4s\n"
                                "ldr q13, [%[inptr], #0x20]\n"
                                "ldr q14, [%[inptr], #0x30]\n"
                                "prfm PSTL1KEEP, [%[outptr2], #0x20]\n"
                                "add v12.4s, v12.4s, v2.4s\n"
                                "str q11, [%[outptr0]]\n"
                                "add v13.4s, v13.4s, v2.4s\n"
                                "add %[outptr0], %[outptr0], #0x10\n"
                                "add v14.4s, v14.4s, v2.4s\n"
                                "str q12, [%[outptr1]]\n"
                                "add %[outptr1], %[outptr1], #0x10\n"
                                "prfm PSTL1KEEP, [%[outptr3], #0x20]\n"
                                "add %[inptr], %[inptr], #0x40\n"
                                "str q13, [%[outptr2]]\n"
                                "add %[outptr2], %[outptr2], #0x10\n"
                                "str q14, [%[outptr3]]\n"
                                "add %[outptr3], %[outptr3], #0x10\n"
                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                              [inptr] "+r" (inptr)
                            : [biasptr] "r" (biasptr)
                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
                            );
                        }
                    }
                    break;
                }
            }
        }
    }
}

#endif // __aarch64__