/*
 * Copyright (c) 2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include <arm_neon.h>

#if !defined(_WIN64) && !defined(__OpenBSD__)
#include <alloca.h>
#endif /* !defined(_WIN64) && !defined(__OpenBSD__) */

#include <cstring>

#include "transform.hpp"
#include "utils.hpp"

namespace arm_gemm {

namespace {

// Helper function to interleave a single 4x4 block of 32-bit values
// together.

// _full version doesn't need to worry about any padding.
static inline void transpose_block_32_full(const uint8_t * __restrict in_ptr0, const uint8_t * __restrict in_ptr1, const uint8_t * __restrict in_ptr2, const uint8_t * __restrict in_ptr3, uint8_t * __restrict out_ptr, long output_stride) {
    uint32x4_t inputs[4];
    uint32x4_t inters[4];
    uint32x4_t outputs[4];

    inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr0));
    inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr1));
    inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr2));
    inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr3));

    inters[0] = vzip1q_u32(inputs[0], inputs[2]);
    inters[1] = vzip2q_u32(inputs[0], inputs[2]);
    inters[2] = vzip1q_u32(inputs[1], inputs[3]);
    inters[3] = vzip2q_u32(inputs[1], inputs[3]);

    outputs[0] = vzip1q_u32(inters[0], inters[2]);
    outputs[1] = vzip2q_u32(inters[0], inters[2]);
    outputs[2] = vzip1q_u32(inters[1], inters[3]);
    outputs[3] = vzip2q_u32(inters[1], inters[3]);

    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
}
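
// For reference: the two vzip stages above form a 4x4 transpose network on
// 32-bit lanes, so output row i holds lane i of each of the four input rows.
// A minimal scalar sketch of the same operation (added here for illustration
// only - it is not called by the kernels below):
static inline void transpose_block_32_reference(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long output_stride) {
    const uint8_t *in_ptrs[4] = { in_ptr0, in_ptr1, in_ptr2, in_ptr3 };

    for (int i = 0; i < 4; i++) {     // i indexes the 32-bit lane, and thus the output row
        for (int j = 0; j < 4; j++) { // j indexes the input row
            memcpy(out_ptr + (i * output_stride) + (j * 4), in_ptrs[j] + (i * 4), 4);
        }
    }
}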

// _part version: Only read "bytes_in" bytes, not a full vector. Only write
// out 4-byte blocks that have some live content (if bytes_in is not a
// multiple of 4 there will be some padding in each 4-byte block).
static inline void transpose_block_32_part(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long bytes_in, long output_stride) {
    uint32x4_t inputs[4];
    uint32x4_t inters[4];
    uint32x4_t outputs[4];
    uint8_t scratch[16] = {0};

    long num_outs = iceildiv<long>(bytes_in, 4);

    memcpy(scratch, in_ptr0, bytes_in);
    inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
    memcpy(scratch, in_ptr1, bytes_in);
    inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
    memcpy(scratch, in_ptr2, bytes_in);
    inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
    memcpy(scratch, in_ptr3, bytes_in);
    inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));

    inters[0] = vzip1q_u32(inputs[0], inputs[2]);
    inters[1] = vzip2q_u32(inputs[0], inputs[2]);
    inters[2] = vzip1q_u32(inputs[1], inputs[3]);
    inters[3] = vzip2q_u32(inputs[1], inputs[3]);

    outputs[0] = vzip1q_u32(inters[0], inters[2]);
    outputs[1] = vzip2q_u32(inters[0], inters[2]);
    outputs[2] = vzip1q_u32(inters[1], inters[3]);
    outputs[3] = vzip2q_u32(inters[1], inters[3]);

    do {
        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
        if (num_outs < 2)
            break;
        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
        if (num_outs < 3)
            break;
        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
        if (num_outs < 4)
            break;
        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
    } while (0);
}
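
// Worked example: with bytes_in == 6, num_outs is iceildiv(6, 4) == 2, so
// only outputs[0] and outputs[1] are stored. The scratch buffer is
// zero-initialised once and only the first bytes_in bytes are overwritten
// per row, so each loaded vector carries zeros beyond the live data, giving
// two bytes of zero padding in the second 4-byte block of every row.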

template<unsigned N>
struct Unroll {
    template<typename F>
    static void run(F f) {
        Unroll<N-1>::run(f);
        f(N-1);
    }
};

template<>
struct Unroll<0> {
    template<typename F>
    static void run(F) {
    }
};
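
// Note: Unroll<N>::run(f) calls f(0), f(1), ..., f(N-1), with the index a
// compile-time constant at each call site. The loops over row pointers
// below are therefore fully unrolled, which is what allows in_ptrs[] to be
// kept entirely in registers.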

// Interleave some multiple of 4 rows together.
//
// The template parameter BLOCKS controls the size of the inner loop - each BLOCK is 4 rows.
// The function parameter interleave_multiple controls the number of times the inner loop is run.
//
// The total interleave depth for a given run is therefore BLOCKS * interleave_multiple * 4.
template<unsigned BLOCKS>
void a64_interleave_1x4(uint8_t *out, const uint8_t *in, long width, long in_stride, long height, long interleave_multiple) {
    const long total_interleave_depth = BLOCKS * 4 * interleave_multiple;
    constexpr long loop_interleave_depth = BLOCKS * 4;

    uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width));

    if (height % total_interleave_depth) {
        memset(pad_row, 0, width);
    }
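
    // pad_row is stack storage for a single zeroed dummy row, substituted
    // for any input row at or beyond "height". It only needs clearing when
    // the height is not a multiple of the interleave depth, i.e. when
    // padding rows will actually be read.
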
    // Outer loop: process blocks of total_interleave_depth rows at a time.
    for (long y0_base=0; y0_base<height; y0_base+=total_interleave_depth) {
        // Middle loop: process each "interleave_multiple" block of rows.
        for (long block=0; block<interleave_multiple; block++) {
            const long y0 = y0_base + (block * loop_interleave_depth);
            uint8_t *out_ptr = out + (block * loop_interleave_depth * 4); // 4 is the blocking depth (we interleave 4 bytes at a time from each input)

            // Create and set up input row pointers. The idea is that these
            // should entirely fit in the register file, so we don't have to
            // repeatedly load them (or perform the padding check).
            const uint8_t *in_ptrs[loop_interleave_depth];
            Unroll<loop_interleave_depth>::run( [&](unsigned y) {
                in_ptrs[y] = (y+y0 < height) ? in + ((y+y0) * in_stride) : pad_row;
            });

            long bytes_left = width;
            // Process full vectors using transpose_block_32_full()
            while (bytes_left >= 16) { // 16 is the vector length in bytes
                Unroll<BLOCKS>::run( [&](unsigned u) {
                    transpose_block_32_full(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3],
                                            out_ptr + 16*u, total_interleave_depth * 4); // 4 is the blocking depth
                });

                Unroll<loop_interleave_depth>::run( [&](unsigned y) {
                    in_ptrs[y] += 16; // 16 is the vector length in bytes
                });

                out_ptr += total_interleave_depth * 16; // 16 is the vector length in bytes
                bytes_left -= 16; // 16 is the vector length in bytes
            }

            // Process any remaining bytes using transpose_block_32_part()
            if (bytes_left) {
                Unroll<BLOCKS>::run( [&](unsigned u) {
                    transpose_block_32_part(in_ptrs[u*4 + 0], in_ptrs[u*4 + 1], in_ptrs[u*4 + 2], in_ptrs[u*4 + 3],
                                            out_ptr + 16*u, bytes_left, total_interleave_depth * 4);
                });
            }
        }

        // Update "out" pointer for next set of total_interleave_depth rows
        out += total_interleave_depth * roundup<long>(width, 4);
    }
}
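
// Output layout note: each iteration of the outer loop above emits one
// packed panel of total_interleave_depth rows, with the row length padded
// up to a multiple of the 4-byte block - hence the advance of
// total_interleave_depth * roundup(width, 4) bytes per panel. For example,
// with BLOCKS=4 and interleave_multiple=1 (depth 16) and width=30, each
// panel occupies 16 * 32 = 512 bytes of output.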

} // anonymous namespace

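// The specialisations below map the generic Transform interface onto
// a64_interleave_1x4. The first template argument matches the interleave
// depth in rows (BLOCKS * 4 * interleave_multiple) and the second the block
// length in elements - always 4 bytes here, i.e. four u8/s8 values or one
// float. The float variants scale width and stride into bytes; for example,
// Transform<24, 1, ...> uses BLOCKS=3 with interleave_multiple=2, giving
// 3 * 4 * 2 = 24 rows.
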
template<>
void Transform<16, 4, false, VLType::None>(
    uint8_t *out, const uint8_t *in, int stride, int y0, int ymax, int x0, int xmax)
{
    a64_interleave_1x4<4>(
        reinterpret_cast<uint8_t *>(out),
        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
        (xmax - x0),
        stride,
        (ymax - y0),
        1
    );
}

template<>
void Transform<16, 4, false, VLType::None>(
    int8_t *out, const int8_t *in, int stride, int y0, int ymax, int x0, int xmax)
{
    a64_interleave_1x4<4>(
        reinterpret_cast<uint8_t *>(out),
        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
        (xmax - x0),
        stride,
        (ymax - y0),
        1
    );
}

template<>
void Transform<12, 1, false, VLType::None>(
    float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
{
    a64_interleave_1x4<3>(
        reinterpret_cast<uint8_t *>(out),
        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
        (xmax - x0) * sizeof(float),
        stride * sizeof(float),
        (ymax - y0),
        1
    );
}

template<>
void Transform<16, 1, false, VLType::None>(
    float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
{
    a64_interleave_1x4<4>(
        reinterpret_cast<uint8_t *>(out),
        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
        (xmax - x0) * sizeof(float),
        stride * sizeof(float),
        (ymax - y0),
        1
    );
}

template<>
void Transform<24, 1, false, VLType::None>(
    float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
{
    a64_interleave_1x4<3>(
        reinterpret_cast<uint8_t *>(out),
        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
        (xmax - x0) * sizeof(float),
        stride * sizeof(float),
        (ymax - y0),
        2
    );
}

} // namespace arm_gemm

#endif // __aarch64__