/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

// Implementations of interleave functions
// These must be included within a "namespace arm_gemm" block.

/*
 * Core function that does the heavy lifting - interleave 'int_by' rows of width 'width' together.
 *
 * 'height' indicates the actual number of rows to interleave, so if it's less than int_by then the remaining
 * entries are padded (note that this is "GEMM" padding rather than convolution padding, so there is no need to
 * pad with a particular value).
 *
 * Note that this templated version is not expected ever to be used - all cases that matter should be
 * explicitly specialized with an optimized implementation.
 */
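/*
 * For illustration only (the values are hypothetical, not taken from a real specialization): with int_by = 4,
 * block = 1, width = 3 and height = 3, the generic loop below emits, for each source column in turn, that
 * column of every interleaved row, zero-padding the missing fourth row:
 *
 *   in[0][0], in[1][0], in[2][0], 0,
 *   in[0][1], in[1][1], in[2][1], 0,
 *   in[0][2], in[1][2], in[2][2], 0
 *
 * (indices shown relative to 'row_offset').
 */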
template<unsigned int height_vectors, unsigned int block, VLType vlt, bool integrate_sums, typename TIn, typename TOut>
void interleave_block( TOut * &out, const TIn * const *in, size_t width, size_t height, size_t row_offset, bool first) {
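    // Effective interleave height in rows: for SVE/SME the template height is given in vectors, so scale it
    // by the runtime vector length (in TOut elements) divided by the block size; otherwise use it directly.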
    const unsigned int int_by = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
                                                  (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));

    std::vector<int32_t> the_sums;

    if (integrate_sums) {
        the_sums = std::vector<int32_t>(int_by, 0);

        if (!first) {
            // In 'integrate sums' mode, we dump the sums at the end on each pass.

            // On the last pass this is correct, but on other passes it is not -
            // so on the subsequent pass we need to take the output written by
            // the previous pass as starting point for the sums, and then
            // overwrite them with new interleaved data.
            int32_t *out_int32 = reinterpret_cast<int32_t *>(out);

            // Rewind pointer to where we wrote out the sums last time.
            out_int32 -= int_by;

            // Restore the running sums.
            memcpy(the_sums.data(), out_int32, int_by * sizeof(int32_t));

            // Update the "real" pointer so that the next output will clobber the old sums.
            out = reinterpret_cast<TOut *>(out_int32);
        }
    }

    for (unsigned int pos=0; pos<width; pos+=block) {
        for (unsigned int row=0; row<int_by; row++) {
            // Row out of range - pad 'block' entries.
            if (row >= height) {
                for (unsigned int col=0; col<block; col++) {
                    *out++ = 0;
                }
                continue;
            }

            for (unsigned int col=0; col<block; col++) {
                // Column out of range - pad a single entry.
                if (pos + col >= width) {
                    *out++ = 0;
                    continue;
                }

                if (integrate_sums) {
                    the_sums[row] += in[row][row_offset + pos + col];
                }

                *out++ = in[row][row_offset + pos + col];
            }
        }
    }

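    // Dump the (possibly still partial) running sums after the interleaved data; a later pass with 'first'
    // unset will pick them up again and continue accumulating (see above).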
    if (integrate_sums) {
        int32_t *out_int32 = reinterpret_cast<int32_t *>(out);

        memcpy(out_int32, the_sums.data(), int_by * sizeof(int32_t));

        out = reinterpret_cast<TOut *>(out_int32 + int_by);
    }
}

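/*
 * Patch up the row-sum block produced (or skipped) by interleave_block<>(): if the multiplier is non-zero,
 * scale the previously accumulated sums in place; if it is zero, write a block of zero sums and advance
 * 'out' past it.
 */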
template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TOut>
inline void FixupRowSums(TOut * &out, const int32_t row_sum_multiplier) {
    const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
                                                  (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));

    // If we are integrating row sums, we need to do some fix up, depending on whether the multiplier is non-zero or not.
    if (row_sum_multiplier) {
        // Non-zero: interleave_block<>() will have done the sums, so 'out' will point to the start of the
        // next block (post sums).
        // We need to go back and apply the multiplier to the computed sums. We don't need to change 'out'.
        int32_t *out_int32 = reinterpret_cast<int32_t *>(out);

        out_int32 -= height;
        for (unsigned int i=0; i<height; i++) {
            out_int32[i] *= row_sum_multiplier;
        }
    } else {
        // Zero: interleave_block<>() will *not* have done the sums, so 'out' will point to the start of the
        // sum block. We need to insert the (zero) sums, and advance 'out'.
        int32_t *out_int32 = reinterpret_cast<int32_t *>(out);

        for (unsigned int i=0; i<height; i++) {
            out_int32[i] = 0;
        }

        out_int32 += height;

        out = reinterpret_cast<TOut *>(out_int32);
    }
}

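/*
 * Interleave from an indirect (pointer-to-pointer) source: 'ptr[string]' holds one row pointer per row, and
 * the K range [k0, kmax) runs over the concatenated strings, each 'stringlen' elements long but padded out
 * to 'rounded_stringlen'.  Rows [y0, ymax) are processed in blocks of 'height' rows.
 */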
template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut>
void IndirectInterleave(TOut *out, const TIn * const * const *ptr, unsigned int stringlen,
                        unsigned int rounded_stringlen, const unsigned int y0, const unsigned int ymax,
                        const unsigned int k0, const unsigned int kmax, bool integrate_sums,
                        const int32_t row_sum_multiplier) {
    const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
                                                  (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));

    // 'interleave_block' implementations are entitled to read a pointer for each row they handle from the input
    // pointer array, even for out of range rows (although they must not subsequently dereference those pointers for
    // out of range rows). This allows interleave_block to use techniques like row predication, or loading all
    // pointers and conditionally overriding the out of range ones.

    // This is problematic in the "pure" indirect case when we get to the last rows, where it can lead to out of
    // range reads. Avoid this with a local buffer to use in last-rows cases. Use alloca as a std::vector can be
    // expensive in highly threaded scenarios.
    const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *)));

    // Figure out the starting position based on k0 (with rounded length).
    unsigned int start_string = k0 / rounded_stringlen;
    unsigned int start_stringpos = k0 % rounded_stringlen;

    // Process blocks of 'height' rows...
    for (unsigned int ybase = y0; ybase < ymax; ybase+=height) {
        // Height to process.
        unsigned int active_height = std::min(ymax - ybase, height);

        // Track our progress through the various strings.
        unsigned int k_left = (kmax - k0);
        unsigned int string = start_string;
        unsigned int stringpos = start_stringpos;

        bool first = true;

        // Prepare to call 'interleave_block' above for each string encompassed by the K range.
        while (k_left > 0) {
            // Width to process - and the width we will generate (with padding).
            unsigned int in_width = std::min(k_left, stringlen - stringpos);
            unsigned int out_width = std::min(k_left, rounded_stringlen - stringpos);

            const TIn * const *row_base = ptr[string] + ybase;

            // If not all rows are valid, copy the ones that are into the local array (see above comment).
            if (active_height < height) {
                for (unsigned int i=0; i<active_height; i++) {
                    row_ptrs[i] = ptr[string][ybase + i];
                }

                row_base = row_ptrs;
            }

            // 'integrate_sums' is a function parameter rather than a template parameter to prevent duplicating too
            // much code. However, integrated sums make no sense for non-integral types and won't ever be
            // requested. So put a type trait check here to avoid generating pointless code.
            if (std::is_integral<TOut>::value && integrate_sums && row_sum_multiplier) {
                interleave_block<height_vectors, block, vlt, true>(out, row_base, in_width, active_height, stringpos, first);
            } else {
                interleave_block<height_vectors, block, vlt, false>(out, row_base, in_width, active_height, stringpos, first);
            }

            k_left -= out_width;
            string++;
            stringpos=0;
            first=false;
        }

        if (std::is_integral<TOut>::value && integrate_sums) {
            FixupRowSums<height_vectors, block, vlt>(out, row_sum_multiplier);
        }
    }
}

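/*
 * As IndirectInterleave, but the row pointers for each block are generated on the fly by the supplied
 * convolver, which fills 'row_ptrs' and reports the (width, offset) of each chunk via next_block().
 */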
template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut>
void ConvolutionInterleave(TOut *out, const TIn *in, size_t in_stride, const convolver<TIn> &conv, const unsigned int rounded_stringlen,
                           const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) {
    const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
                                                  (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
    auto conv_cols = conv.process_columns(in, in_stride, k0, kmax, rounded_stringlen);

    // Use alloca here as a std::vector can be expensive in highly threaded scenarios.
    const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *)));

    for (unsigned int ybase = y0; ybase < ymax; ybase += height) {
        // How many of the rows are active - the rest will get padded in interleave_block.
        unsigned int active_height = std::min(ymax - ybase, height);
        bool first = true;

        auto conv_rows = conv_cols.process_rows(ybase, active_height);

        while (!conv_rows.finished()) {
            unsigned int width, offset;

            // Get next set of parameters
            std::tie(width, offset) = conv_rows.next_block(row_ptrs);

            // Perform the interleave
            if (std::is_integral<TOut>::value && integrate_sums && row_sum_multiplier) {
                interleave_block<height_vectors, block, vlt, true>(out, row_ptrs, width, active_height, offset, first);
            } else {
                interleave_block<height_vectors, block, vlt, false>(out, row_ptrs, width, active_height, offset, first);
            }

            first=false;
        }

        if (std::is_integral<TOut>::value && integrate_sums) {
            FixupRowSums<height_vectors, block, vlt>(out, row_sum_multiplier);
        }
    }
}

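/*
 * Interleave from a plain dense matrix with row stride 'in_stride' (in elements): rows [y0, ymax) and
 * columns [k0, kmax) are packed in blocks of 'height' rows, with row sums integrated if requested.
 */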
template<unsigned int height_vectors, unsigned int block, VLType vlt, typename TIn, typename TOut>
void Interleave(TOut *out, const TIn *in, size_t in_stride, const unsigned int y0, const unsigned int ymax, const unsigned int k0, const unsigned int kmax, bool integrate_sums, const int32_t row_sum_multiplier) {
    const unsigned int height = height_vectors * (vlt == VLType::SVE ? get_vector_length<TOut>() / block :
                                                  (vlt == VLType::SME ? sme::get_vector_length<TOut>() / block : 1 ));
    // Use alloca here as a std::vector can be expensive in highly threaded scenarios.
    const TIn **row_ptrs = reinterpret_cast<const TIn **>(alloca(height * sizeof(const TIn *)));

    const unsigned int width=kmax-k0;

    for (unsigned int y=y0; y<ymax; y+=height) {
        for (unsigned int r=0; r<height; r++) {
            row_ptrs[r] = in + ((y + r) * in_stride);
        }

        if (std::is_integral<TOut>::value && integrate_sums && row_sum_multiplier) {
            interleave_block<height_vectors, block, vlt, true>(out, row_ptrs, width, std::min(height, ymax-y), k0, true);
        } else {
            interleave_block<height_vectors, block, vlt, false>(out, row_ptrs, width, std::min(height, ymax-y), k0, true);
        }

        if (std::is_integral<TOut>::value && integrate_sums) {
            FixupRowSums<height_vectors, block, vlt>(out, row_sum_multiplier);
        }
    }
}
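
/*
 * A minimal usage sketch (the buffer sizing and the VLType::None enumerator are assumed here, not taken from
 * a specific kernel): packing an 8-row, 64-column uint8_t matrix 'src' into 8-row panels with a block size
 * of 4 and no row sums could look like:
 *
 *   uint8_t *panel = ...;  // destination buffer, sized for the padded interleaved output
 *   Interleave<8, 4, VLType::None>(panel, src, 64, 0, 8, 0, 64, false, 0);
 *
 * In practice the template parameters are chosen to match the panel layout expected by the GEMM kernel that
 * will consume the packed data.
 */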