/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

#include <cstddef>
#include <arm_neon.h>

namespace arm_conv {
namespace winograd {
namespace weight_transform {

/**
 * FP16 weight transform for 3x3 kernels producing 6x6 transformed tiles.
 *
 * For every channel the 3x3 weight tile `w` is mapped to the 6x6 tile
 *
 *     V = (W w W^T) / 576
 *
 * where W is the 6x3 matrix
 *
 *     [  6   0   0 ]
 *     [ -4  -4  -4 ]
 *     [ -4   4  -4 ]
 *     [  1   2   4 ]
 *     [  1  -2   4 ]
 *     [  0   0  24 ]
 *
 * (576 == 24 * 24, i.e. one factor of 24 per application of W.)
 *
 * Channels are processed eight at a time with 128-bit vectors; any remaining
 * channels are handled by the narrower loops that follow.
 *
 * @param n_channels     Number of consecutive channels to transform.
 * @param inptr          Pointer to the first channel of weight element (0, 0).
 *                       Element (i, j) of channel c is read from
 *                       `inptr + i*ld_weight_row + j*ld_weight_col + c`.
 * @param ld_weight_row  Stride, in elements, between kernel rows.
 * @param ld_weight_col  Stride, in elements, between kernel columns.
 * @param outptr         Pointer to the first channel of transformed element 0.
 *                       Element (i, j) of channel c is written to
 *                       `outptr + (6*i + j)*matrix_stride + c`.
 * @param matrix_stride  Stride, in elements, between the 36 output matrices.
 */
void a64_fp16_4x4_3x3(
  unsigned int n_channels,
  const __fp16* inptr,  // NOTE: Data in HWIO order
  const size_t ld_weight_row,
  const size_t ld_weight_col,
  __fp16* outptr,
  const size_t matrix_stride
)
{
#ifdef __aarch64__
  // Main loop: eight channels per iteration, one channel per vector lane.
  for (; n_channels >= 8; n_channels -= 8)
  {
    // Matrices used and computed in this kernel
    float16x8_t w[3][3], Ww[6][3], V[6][6];

    // Read weights
    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = vld1q_f16(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 3; j++)
    {
      // Ww[0][j] = 6*w[0][j];
      Ww[0][j] = vmulq_n_f16(w[0][j], 6.0);

      // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
      Ww[1][j] = vmulq_n_f16(vaddq_f16(vaddq_f16(w[0][j], w[1][j]), w[2][j]), -4.0);

      // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
      Ww[2][j] = vmulq_n_f16(vsubq_f16(vsubq_f16(w[1][j], w[0][j]), w[2][j]), 4.0);

      // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
      Ww[3][j] = vaddq_f16(vaddq_f16(w[0][j], vmulq_f16(w[1][j], vdupq_n_f16(2.0f))), vmulq_f16(w[2][j], vdupq_n_f16(4.0f)));

      // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
      Ww[4][j] = vaddq_f16(vsubq_f16(w[0][j], vmulq_f16(w[1][j], vdupq_n_f16(2.0f))), vmulq_f16(w[2][j], vdupq_n_f16(4.0f)));

      // Ww[5][j] = 24*w[2][j];
      Ww[5][j] = vmulq_n_f16(w[2][j], 24.0f);
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      // Normalisation applied as the last step of every V entry.
      const float recip576 = 1.0f / 576.0f;

      // V[i][0] = 6*Ww[i][0];
      V[i][0] = vmulq_n_f16(vmulq_n_f16(Ww[i][0], 6.0), recip576);

      // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
      V[i][1] = vmulq_n_f16(vmulq_n_f16(vaddq_f16(vaddq_f16(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);

      // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
      V[i][2] = vmulq_n_f16(vmulq_n_f16(vsubq_f16(vsubq_f16(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);

      // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
      V[i][3] = vmulq_n_f16(vaddq_f16(vaddq_f16(Ww[i][0], vmulq_f16(Ww[i][1], vdupq_n_f16(2.0f))), vmulq_f16(Ww[i][2], vdupq_n_f16(4.0f))), recip576);

      // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
      V[i][4] = vmulq_n_f16(vaddq_f16(vsubq_f16(Ww[i][0], vmulq_f16(Ww[i][1], vdupq_n_f16(2.0f))), vmulq_f16(Ww[i][2], vdupq_n_f16(4.0f))), recip576);

      // V[i][5] = 24*Ww[i][2];
      V[i][5] = vmulq_n_f16(vmulq_n_f16(Ww[i][2], 24.0f), recip576);
    }

    // Store the transformed weights; m = 6*i + j selects the output matrix.
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1q_f16(outptr + m*matrix_stride, V[i][j]);
      }
    }
    inptr += 8;
    outptr += 8;
  }
#endif  // __aarch64__
#ifdef __arm_any__
  // NOTE(review): `__arm_any__` is not defined anywhere in this file; unless
  // the build system or an earlier header defines it, this four-channel path
  // is compiled out and leftovers fall through to the scalar loop below.
  for (; n_channels >= 4; n_channels -= 4)
  {
    // Matrices used and computed in this kernel
    float16x4_t w[3][3], Ww[6][3], V[6][6];

    // Read weights
    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = vld1_f16(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 3; j++)
    {
      // Ww[0][j] = 6*w[0][j];
      Ww[0][j] = vmul_n_f16(w[0][j], 6.0);

      // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
      Ww[1][j] = vmul_n_f16(vadd_f16(vadd_f16(w[0][j], w[1][j]), w[2][j]), -4.0);

      // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
      Ww[2][j] = vmul_n_f16(vsub_f16(vsub_f16(w[1][j], w[0][j]), w[2][j]), 4.0);

      // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
      Ww[3][j] = vadd_f16(vadd_f16(w[0][j], vmul_f16(w[1][j], vdup_n_f16(2.0f))), vmul_f16(w[2][j], vdup_n_f16(4.0f)));

      // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
      Ww[4][j] = vadd_f16(vsub_f16(w[0][j], vmul_f16(w[1][j], vdup_n_f16(2.0f))), vmul_f16(w[2][j], vdup_n_f16(4.0f)));

      // Ww[5][j] = 24*w[2][j];
      Ww[5][j] = vmul_n_f16(w[2][j], 24.0f);
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      // Normalisation applied as the last step of every V entry.
      const float recip576 = 1.0f / 576.0f;

      // V[i][0] = 6*Ww[i][0];
      V[i][0] = vmul_n_f16(vmul_n_f16(Ww[i][0], 6.0), recip576);

      // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
      V[i][1] = vmul_n_f16(vmul_n_f16(vadd_f16(vadd_f16(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);

      // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
      V[i][2] = vmul_n_f16(vmul_n_f16(vsub_f16(vsub_f16(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);

      // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
      V[i][3] = vmul_n_f16(vadd_f16(vadd_f16(Ww[i][0], vmul_f16(Ww[i][1], vdup_n_f16(2.0f))), vmul_f16(Ww[i][2], vdup_n_f16(4.0f))), recip576);

      // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
      V[i][4] = vmul_n_f16(vadd_f16(vsub_f16(Ww[i][0], vmul_f16(Ww[i][1], vdup_n_f16(2.0f))), vmul_f16(Ww[i][2], vdup_n_f16(4.0f))), recip576);

      // V[i][5] = 24*Ww[i][2];
      V[i][5] = vmul_n_f16(vmul_n_f16(Ww[i][2], 24.0f), recip576);
    }

    // Store the transformed weights; m = 6*i + j selects the output matrix.
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        vst1_f16(outptr + m*matrix_stride, V[i][j]);
      }
    }
    inptr += 4;
    outptr += 4;
  }
#endif  // __arm_any__
  // Scalar tail: one channel per iteration for whatever the vector loops left.
  // NOTE(review): intermediate rounding here may differ from the vector paths
  // (which round to FP16 after every intrinsic) - confirm callers tolerate it.
  for (; n_channels; n_channels--)
  {
    // Matrices used and computed in this kernel
    __fp16 w[3][3], Ww[6][3], V[6][6];

    // Read weights
    for (int i = 0; i < 3; i++)
    {
      for (int j = 0; j < 3; j++)
      {
        w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col);
      }
    }

    // Compute the matrix W w
    for (int j = 0; j < 3; j++)
    {
      Ww[0][j] =  6*w[0][j];
      Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
      Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
      Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
      Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
      Ww[5][j] = 24*w[2][j];
    }

    // Compute V = W w WT
    for (int i = 0; i < 6; i++)
    {
      V[i][0] = ( 6*Ww[i][0]) / 576.0;
      V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
      V[i][2] = (-4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
      V[i][3] = ( 1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
      V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
      V[i][5] = (24*Ww[i][2]) / 576.0;
    }

    // Store the transformed weights; m = 6*i + j selects the output matrix.
    for (int i = 0, m = 0; i < 6; i++)
    {
      for (int j = 0; j < 6; j++, m++)
      {
        *(outptr + m*matrix_stride) = V[i][j];
      }
    }

    inptr++;
    outptr++;
  }
}

}  // namespace weight_transform
}  // namespace winograd
}  // namespace arm_conv

#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)