/*
 * Copyright (c) 2022 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <cstddef>
#include <arm_neon.h>

namespace arm_conv {
namespace winograd {
namespace weight_transform {

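// Winograd F(2x2, 5x5) weight transform for fp32: as the kernel name and the
// coefficients below indicate, each 5x5 kernel w is mapped to a 6x6 matrix
// V = G w G^T, computed in two stages as Ww = G w followed by V = Ww G^T.
//
// inptr points at the weights for the channels handled by this call, with
// ld_weight_row and ld_weight_col giving the element strides between kernel
// rows and columns; the m-th of the 36 transformed values for a channel is
// written to outptr + m*matrix_stride.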
void arm_fp32_2x2_5x5(
  unsigned int n_channels,
  const float *inptr, const size_t ld_weight_row, const size_t ld_weight_col,
  float *outptr, const size_t matrix_stride
)
{
#ifdef __aarch64__
  // Process four channels at a time using 128-bit NEON vectors
  for (; n_channels >= 4; n_channels -= 4)
  {
    // Matrices used and computed in this kernel
    float32x4_t w[5][5], Ww[6][5], V[6][6];

    // Read weights
    for (int i = 0; i < 5; i++)
    {
      for (int j = 0; j < 5; j++)
      {
        w[i][j] = vld1q_f32(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
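    // i.e. Ww = G w, applying the 6x5 transform matrix implied by the
    // coefficients below to each column of the kernel:
    //   G = [  1/4      0     0     0     0 ]
    //       [ -1/6   -1/6  -1/6  -1/6  -1/6 ]
    //       [ -1/6    1/6  -1/6   1/6  -1/6 ]
    //       [ 1/24   1/12   1/6   1/3   2/3 ]
    //       [ 1/24  -1/12   1/6  -1/3   2/3 ]
    //       [    0      0     0     0     1 ]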
    for (int j = 0; j < 5; j++)
    {
      // Ww[0][j] = w[0][j]/4.0f;
      Ww[0][j] = vmulq_n_f32(w[0][j], 1.0f/4.0f);

      // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f;
      Ww[1][j] = vmulq_n_f32(
        vaddq_f32(
          vaddq_f32(
            vaddq_f32(w[1][j], w[0][j]),
            vaddq_f32(w[3][j], w[2][j])
          ),
          w[4][j]
        ),
        -1.0f/6.0f
      );

      // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f;
      // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f;
      Ww[2][j] = vmulq_n_f32(
        vsubq_f32(
          vaddq_f32(
            vsubq_f32(w[1][j], w[0][j]),
            vsubq_f32(w[3][j], w[2][j])
          ),
          w[4][j]
        ),
        1.0f/6.0f
      );

      // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f;
      Ww[3][j] = vmulq_n_f32(
        vmlaq_n_f32(
          vaddq_f32(
            vaddq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)),
            vaddq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
          ),
          w[4][j], 2.0f
        ),
        1.0f/3.0f
      );

      // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f;
      Ww[4][j] = vmulq_n_f32(
        vmlaq_n_f32(
          vaddq_f32(
            vsubq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)),
            vsubq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
          ),
          w[4][j], 2.0f
        ),
        1.0f/3.0f
      );

      // Ww[5][j] = w[4][j];
      Ww[5][j] = w[4][j];
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      // V[i][0] = Ww[i][0]/4.0f;
      V[i][0] = vmulq_n_f32(Ww[i][0], 1.0f/4.0f);

      // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f;
      V[i][1] = vmulq_n_f32(
        vaddq_f32(
          vaddq_f32(
            vaddq_f32(Ww[i][1], Ww[i][0]),
            vaddq_f32(Ww[i][3], Ww[i][2])
          ),
          Ww[i][4]
        ),
        -1.0f/6.0f
      );

      // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f;
      // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f;
      V[i][2] = vmulq_n_f32(
        vsubq_f32(
          vaddq_f32(
            vsubq_f32(Ww[i][1], Ww[i][0]),
            vsubq_f32(Ww[i][3], Ww[i][2])
          ),
          Ww[i][4]
        ),
        1.0f/6.0f
      );

      // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f;
      V[i][3] = vmulq_n_f32(
        vmlaq_n_f32(
          vaddq_f32(
            vaddq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)),
            vaddq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
          ),
          Ww[i][4], 2.0f
        ),
        1.0f/3.0f
      );

      // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f;
      V[i][4] = vmulq_n_f32(
        vmlaq_n_f32(
          vaddq_f32(
            vsubq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)),
            vsubq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
          ),
          Ww[i][4], 2.0f
        ),
        1.0f/3.0f
      );

      // V[i][5] = Ww[i][4];
      V[i][5] = Ww[i][4];
    }

    // Store the transformed weights
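    // The 36 values for this block of channels go one per Winograd matrix,
    // spaced matrix_stride elements apart.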
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1q_f32(outptr + m*matrix_stride, V[i][j]);
      }
    }

    inptr += 4;
    outptr += 4;
  }
#endif // __aarch64__
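  // Process pairs of channels using 64-bit NEON vectors (covers whatever the
  // four-wide loop above did not handle)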
  for (; n_channels >= 2; n_channels -= 2)
  {
    // Matrices used and computed in this kernel
    float32x2_t w[5][5], Ww[6][5], V[6][6];

    // Read weights
    for (int i = 0; i < 5; i++)
    {
      for (int j = 0; j < 5; j++)
      {
        w[i][j] = vld1_f32(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 5; j++)
    {
      // Ww[0][j] = w[0][j]/4.0f;
      Ww[0][j] = vmul_n_f32(w[0][j], 1.0f/4.0f);

      // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f;
      Ww[1][j] = vmul_n_f32(
        vadd_f32(
          vadd_f32(
            vadd_f32(w[1][j], w[0][j]),
            vadd_f32(w[3][j], w[2][j])
          ),
          w[4][j]
        ),
        -1.0f/6.0f
      );

      // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f;
      // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f;
      Ww[2][j] = vmul_n_f32(
        vsub_f32(
          vadd_f32(
            vsub_f32(w[1][j], w[0][j]),
            vsub_f32(w[3][j], w[2][j])
          ),
          w[4][j]
        ),
        1.0f/6.0f
      );

      // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f;
      Ww[3][j] = vmul_n_f32(
        vmla_n_f32(
          vadd_f32(
            vadd_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)),
            vadd_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
          ),
          w[4][j], 2.0f
        ),
        1.0f/3.0f
      );

      // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f;
      Ww[4][j] = vmul_n_f32(
        vmla_n_f32(
          vadd_f32(
            vsub_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)),
            vsub_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
          ),
          w[4][j], 2.0f
        ),
        1.0f/3.0f
      );

      // Ww[5][j] = w[4][j];
      Ww[5][j] = w[4][j];
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      // V[i][0] = Ww[i][0]/4.0f;
      V[i][0] = vmul_n_f32(Ww[i][0], 1.0f/4.0f);

      // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f;
      V[i][1] = vmul_n_f32(
        vadd_f32(
          vadd_f32(
            vadd_f32(Ww[i][1], Ww[i][0]),
            vadd_f32(Ww[i][3], Ww[i][2])
          ),
          Ww[i][4]
        ),
        -1.0f/6.0f
      );

      // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f;
      // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f;
      V[i][2] = vmul_n_f32(
        vsub_f32(
          vadd_f32(
            vsub_f32(Ww[i][1], Ww[i][0]),
            vsub_f32(Ww[i][3], Ww[i][2])
          ),
          Ww[i][4]
        ),
        1.0f/6.0f
      );

      // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f;
      V[i][3] = vmul_n_f32(
        vmla_n_f32(
          vadd_f32(
            vadd_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)),
            vadd_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
          ),
          Ww[i][4], 2.0f
        ),
        1.0f/3.0f
      );

      // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f;
      V[i][4] = vmul_n_f32(
        vmla_n_f32(
          vadd_f32(
            vsub_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)),
            vsub_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
          ),
          Ww[i][4], 2.0f
        ),
        1.0f/3.0f
      );

      // V[i][5] = Ww[i][4];
      V[i][5] = Ww[i][4];
    }

    // Store the transformed weights
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1_f32(outptr + m*matrix_stride, V[i][j]);
      }
    }

    inptr += 2;
    outptr += 2;
  }
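  // Handle any remaining channel with scalar code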
  for (; n_channels; n_channels--)
  {
    // Matrices used and computed in this kernel
    float w[5][5], Ww[6][5], V[6][6];

    // Read weights
    for (int i = 0; i < 5; i++)
    {
      for (int j = 0; j < 5; j++)
      {
        w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 5; j++)
    {
      Ww[0][j] = w[0][j]/4.0f;
      Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f;
      Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f;
      Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f;
      Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f;
      Ww[5][j] = w[4][j];
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      V[i][0] = Ww[i][0]/4.0f;
      V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f;
      V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f;
      V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f;
      V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f;
      V[i][5] = Ww[i][4];
    }

    // Store the transformed weights
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        *(outptr + m*matrix_stride) = V[i][j];
      }
    }

    inptr++;
    outptr++;
  }
}

} // namespace weight_transform
} // namespace winograd
} // namespace arm_conv