/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#pragma once

#ifdef __ARM_FEATURE_SVE

30namespace {
31
32void sve_transpose_interleave_3VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height)
33{
34 size_t out_stride = 3 * height * get_vector_length<uint8_t>();
35
36 __asm__ __volatile__(
37 "ptrue p2.b\n"
38 "cmp %x[height], #0x4\n"
39 "blt 4f\n"
40 "1:" // Main row loop: Head
41 "mov x26, %x[in]\n"
42 "mov x25, %x[out]\n"
43 "add x24, x26, %x[in_stride]\n"
44 "add x23, x24, %x[in_stride]\n"
45 "add x22, x23, %x[in_stride]\n"
46 "add %x[in], x22, %x[in_stride]\n"
47 "sub %x[height], %x[height], #0x4\n"
48 "mov x21, %x[width]\n"
49 "2:" // Main row loop: Column loop
50 "mov x20, x21\n"
51 "mov x19, x25\n"
52 "whilelt p0.h, XZR, x20\n"
53 "ld1h { z27.h }, p0/Z, [x26]\n"
54 "ld1h { z26.h }, p0/Z, [x24]\n"
55 "dech x20\n"
56 "ld1h { z25.h }, p0/Z, [x23]\n"
57 "whilelt p1.h, XZR, x20\n"
58 "ld1h { z24.h }, p0/Z, [x22]\n"
59 "dech x20\n"
60 "ld1h { z23.h }, p1/Z, [x26, #1, MUL VL]\n"
61 "whilelt p0.h, XZR, x20\n"
62 "ld1h { z22.h }, p1/Z, [x24, #1, MUL VL]\n"
63 "add x25, x25, %x[out_stride]\n"
64 "ld1h { z21.h }, p0/Z, [x26, #2, MUL VL]\n"
65 "addvl x26, x26, #3\n"
66 "ld1h { z20.h }, p0/Z, [x24, #2, MUL VL]\n"
67 "addvl x24, x24, #3\n"
68 "ld1h { z19.h }, p1/Z, [x23, #1, MUL VL]\n"
69 "dech x21, ALL, MUL #3\n"
70 "ld1h { z18.h }, p0/Z, [x23, #2, MUL VL]\n"
71 "addvl x23, x23, #3\n"
72 "ld1h { z17.h }, p1/Z, [x22, #1, MUL VL]\n"
73 "cmp x21, #0x0\n"
74 "ld1h { z16.h }, p0/Z, [x22, #2, MUL VL]\n"
75 "addvl x22, x22, #3\n"
76 "st1h { z27.h }, p2, [x19]\n"
77 "st1h { z23.h }, p2, [x19, #1, MUL VL]\n"
78 "st1h { z21.h }, p2, [x19, #2, MUL VL]\n"
79 "st1h { z26.h }, p2, [x19, #3, MUL VL]\n"
80 "st1h { z22.h }, p2, [x19, #4, MUL VL]\n"
81 "st1h { z20.h }, p2, [x19, #5, MUL VL]\n"
82 "st1h { z25.h }, p2, [x19, #6, MUL VL]\n"
83 "st1h { z19.h }, p2, [x19, #7, MUL VL]\n"
84 "addvl x19, x19, #12\n"
85 "st1h { z18.h }, p2, [x19, #-4, MUL VL]\n"
86 "st1h { z24.h }, p2, [x19, #-3, MUL VL]\n"
87 "st1h { z17.h }, p2, [x19, #-2, MUL VL]\n"
88 "st1h { z16.h }, p2, [x19, #-1, MUL VL]\n"
89 "bgt 2b\n"
90 "3:" // Main row loop: Column loop skip
91 "addvl %x[out], %x[out], #12\n"
92 "cmp %x[height], #0x4\n"
93 "bge 1b\n"
94 "cbz %x[height], 8f\n"
95 "4:" // Main loop skip
96
97 "5:" // Tail row loop: Head
98 "mov x26, %x[in]\n"
99 "mov x25, %x[out]\n"
100 "add %x[in], x26, %x[in_stride]\n"
101 "sub %x[height], %x[height], #0x1\n"
102 "mov x20, %x[width]\n"
103 "6:" // Tail row loop: Column loop
104 "mov x19, x20\n"
105 "dech x20, ALL, MUL #3\n"
106 "whilelt p0.h, XZR, x19\n"
107 "ld1h { z18.h }, p0/Z, [x26]\n"
108 "dech x19\n"
109 "whilelt p0.h, XZR, x19\n"
110 "ld1h { z17.h }, p0/Z, [x26, #1, MUL VL]\n"
111 "dech x19\n"
112 "whilelt p0.h, XZR, x19\n"
113 "ld1h { z16.h }, p0/Z, [x26, #2, MUL VL]\n"
114 "st1h { z18.h }, p2, [x25]\n"
115 "addvl x26, x26, #3\n"
116 "st1h { z17.h }, p2, [x25, #1, MUL VL]\n"
117 "cmp x20, #0x0\n"
118 "st1h { z16.h }, p2, [x25, #2, MUL VL]\n"
119 "add x25, x25, %x[out_stride]\n"
120 "bgt 6b\n"
121 "7:" // Tail row loop: Column loop skip
122 "addvl %x[out], %x[out], #3\n"
123 "cmp %x[height], #0x1\n"
124 "bge 5b\n"
125 "8:" // Done
126
127 : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out)
128 : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width)
129 : "cc", "memory", "p0", "p1", "p2", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27"
130 );
131}
132
133} // anonymous namespace
134
135template<>
136void Transform<3, 1, true, VLType::SVE>(
137 float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax)
138{
139 sve_transpose_interleave_3VL(
140 reinterpret_cast<uint16_t *>(out),
141 reinterpret_cast<const uint16_t *>(in + k0 * stride + x0),
142 (xmax-x0) * sizeof(float) / 2,
143 stride * sizeof(float),
144 (kmax-k0)
145 );
146}
147
148template<>
149void Transform<3, 1, true, VLType::SVE>(
150 __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax)
151{
152 sve_transpose_interleave_3VL(
153 reinterpret_cast<uint16_t *>(out),
154 reinterpret_cast<const uint16_t *>(in + k0 * stride + x0),
155 (xmax-x0) * sizeof(__fp16) / 2,
156 stride * sizeof(__fp16),
157 (kmax-k0)
158 );
159}
160
161template<>
162void Transform<3, 1, true, VLType::SVE>(
163 double *out, const double *in, int stride, int x0, int xmax, int k0, int kmax)
164{
165 sve_transpose_interleave_3VL(
166 reinterpret_cast<uint16_t *>(out),
167 reinterpret_cast<const uint16_t *>(in + k0 * stride + x0),
168 (xmax-x0) * sizeof(double) / 2,
169 stride * sizeof(double),
170 (kmax-k0)
171 );
172}
173
#endif // __ARM_FEATURE_SVE