/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <cstddef>
#include <arm_neon.h>

namespace arm_conv {
namespace winograd {
namespace weight_transform {

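// Winograd F(4x4, 3x3) fp32 weight transform.
//
// For each input channel this kernel reads a 3x3 weight tile w and writes the
// 6x6 transformed tile V = G w G^T, spread across 36 output matrices. The
// arithmetic below applies the transform with its coefficients scaled by 24
// and divides the result by 24 * 24 == 576 at the end.
//
// Note on the arguments (inferred from how they are used below rather than
// from accompanying documentation): ld_weight_row and ld_weight_col are the
// element strides between kernel rows and columns of the input weights,
// channels are laid out contiguously in both buffers, and matrix_stride is
// the element stride between consecutive output transform matrices.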
void arm_fp32_4x4_3x3(
  unsigned int n_channels,
  const float *inptr, const size_t ld_weight_row, const size_t ld_weight_col,
  float *outptr, const size_t matrix_stride
)
{
#ifdef __aarch64__
  for (; n_channels >= 4; n_channels -= 4)
  {
    // Matrices used and computed in this kernel
    float32x4_t w[3][3], Ww[6][3], V[6][6];

    // Read weights
    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = vld1q_f32(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

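    // The six rows computed below apply the weight-transform coefficients
    // scaled by 24, which keeps every entry a small integer (observed from
    // the coefficients, this is 24 times the usual F(4, 3) Winograd matrix G;
    // the combined factor of 24 * 24 == 576 is divided out when V is formed):
    //
    //   [  6   0   0 ]
    //   [ -4  -4  -4 ]
    //   [ -4   4  -4 ]
    //   [  1   2   4 ]
    //   [  1  -2   4 ]
    //   [  0   0  24 ]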
    // Compute the matrix W w
    for (int j = 0; j < 3; j++)
    {
      // Ww[0][j] = 6*w[0][j];
      Ww[0][j] = vmulq_n_f32(w[0][j], 6.0);

      // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
      Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), -4.0);

      // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
      Ww[2][j] = vmulq_n_f32(vsubq_f32(vsubq_f32(w[1][j], w[0][j]), w[2][j]), 4.0);

      // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
      Ww[3][j] = vmlaq_n_f32(vmlaq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);

      // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
      Ww[4][j] = vmlaq_n_f32(vmlsq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);

      // Ww[5][j] = 24*w[2][j];
      Ww[5][j] = vmulq_n_f32(w[2][j], 24.0f);
    }

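    // The second pass applies the same scaled coefficients to the rows of Ww
    // (multiplying by the transpose on the right) and folds the 1/576
    // normalisation into each result, so V comes out as G w G^T.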
    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      const float recip576 = 1.0f / 576.0f;

      // V[i][0] = 6*Ww[i][0];
      V[i][0] = vmulq_n_f32(vmulq_n_f32(Ww[i][0], 6.0), recip576);

      // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
      V[i][1] = vmulq_n_f32(vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);

      // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
      V[i][2] = vmulq_n_f32(vmulq_n_f32(vsubq_f32(vsubq_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);

      // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
      V[i][3] = vmulq_n_f32(vmlaq_n_f32(vmlaq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);

      // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
      V[i][4] = vmulq_n_f32(vmlaq_n_f32(vmlsq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);

      // V[i][5] = 24*Ww[i][2];
      V[i][5] = vmulq_n_f32(vmulq_n_f32(Ww[i][2], 24.0f), recip576);
    }

    // Store the transformed weights
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1q_f32(outptr + m*matrix_stride, V[i][j]);
      }
    }

    inptr += 4;
    outptr += 4;
  }
#endif // __aarch64__
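  // Two-channel NEON path: handles any pair of channels left over from the
  // block above, and all channels when __aarch64__ is not defined.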
  for (; n_channels >= 2; n_channels -= 2)
  {
    // Matrices used and computed in this kernel
    float32x2_t w[3][3], Ww[6][3], V[6][6];

    // Read weights
    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = vld1_f32(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 3; j++)
    {
      // Ww[0][j] = 6*w[0][j];
      Ww[0][j] = vmul_n_f32(w[0][j], 6.0);

      // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
      Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), -4.0);

      // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
      Ww[2][j] = vmul_n_f32(vsub_f32(vsub_f32(w[1][j], w[0][j]), w[2][j]), 4.0);

      // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
      Ww[3][j] = vmla_n_f32(vmla_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);

      // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
      Ww[4][j] = vmla_n_f32(vmls_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);

      // Ww[5][j] = 24*w[2][j];
      Ww[5][j] = vmul_n_f32(w[2][j], 24.0f);
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      const float recip576 = 1.0f / 576.0f;

      // V[i][0] = 6*Ww[i][0];
      V[i][0] = vmul_n_f32(vmul_n_f32(Ww[i][0], 6.0), recip576);

      // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
      V[i][1] = vmul_n_f32(vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);

      // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
      V[i][2] = vmul_n_f32(vmul_n_f32(vsub_f32(vsub_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);

      // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
      V[i][3] = vmul_n_f32(vmla_n_f32(vmla_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);

      // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
      V[i][4] = vmul_n_f32(vmla_n_f32(vmls_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);

      // V[i][5] = 24*Ww[i][2];
      V[i][5] = vmul_n_f32(vmul_n_f32(Ww[i][2], 24.0f), recip576);
    }

    // Store the transformed weights
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1_f32(outptr + m*matrix_stride, V[i][j]);
      }
    }

    inptr += 2;
    outptr += 2;
  }
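  // Scalar tail: transform any remaining channel one at a time with plain
  // float arithmetic.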
  for (; n_channels; n_channels--)
  {
    // Matrices used and computed in this kernel
    float w[3][3], Ww[6][3], V[6][6];

    // Read weights
    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 3; j++)
    {
      Ww[0][j] =  6*w[0][j];
      Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
      Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
      Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
      Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
      Ww[5][j] = 24*w[2][j];
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      V[i][0] = ( 6*Ww[i][0]) / 576.0;
      V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
      V[i][2] = (-4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
      V[i][3] = ( 1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
      V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
      V[i][5] = (24*Ww[i][2]) / 576.0;
    }

    // Store the transformed weights
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        *(outptr + m*matrix_stride) = V[i][j];
      }
    }

    inptr++;
    outptr++;
  }
}

} // namespace weight_transform
} // namespace winograd
} // namespace arm_conv