/*
 * Copyright (c) 2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#include "input.hpp"
#include "arm.hpp"

namespace winograd
{

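// Specialisation of the 4x4 Winograd input-tile transform for __fp16 input
// and output. For each channel the kernel below computes U = XT . x . X over
// a 4x4 input tile x, where XT is
//   [ 1  0 -1  0 ]
//   [ 0  1  1  0 ]
//   [ 0 -1  1  0 ]
//   [ 0  1  0 -1 ]
// (the input-transform matrix commonly written as B^T for F(2x2, 3x3)
// Winograd convolution), and stores the 16 transformed values matrix_stride
// elements apart starting at outptr.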
template <>
void InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>::transform_tile(
  const int n_channels,
  const __fp16* const input_base,
  const int input_row_stride,
  const int input_col_stride,
  __fp16* outptr,
  const int matrix_stride
)
{
  constexpr int inner_tile_rows = 4, inner_tile_cols = 4;

  // Get pointers into the input tile
  const __fp16 *x_ptrs[inner_tile_rows][inner_tile_cols];
  for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
  {
    // Get a pointer into the row
    const __fp16* const row_ptr = input_base + xi*input_row_stride;

    for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
    {
      x_ptrs[i][j] = row_ptr + xj*input_col_stride;
    }
  }

  // Matrices used/computed in this kernel.
  __fp16 x[inner_tile_rows][inner_tile_cols];
  __fp16 XTx[inner_tile_rows][inner_tile_cols];
  __fp16 U[inner_tile_rows][inner_tile_cols];

  for (int i = 0; i < inner_tile_rows; i++)
  {
    for (int j = 0; j < inner_tile_cols; j++)
    {
      x[i][j] = XTx[i][j] = 0.0f;
    }
  }

  // Perform the Winograd input transformation for each channel in the input
  // tensor.
  int channels_remaining = n_channels;
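  // Channels are processed in blocks: 8 at a time with 128-bit NEON fp16
  // vectors on AArch64, then 4 at a time with 64-bit vectors, and finally a
  // scalar loop handles any remaining tail channels.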
#ifdef __aarch64__
  for (; channels_remaining >= 8; channels_remaining -= 8)
  {
    // Matrices used/computed in this kernel.
    float16x8_t x[inner_tile_rows][inner_tile_cols];
    float16x8_t XTx[inner_tile_rows][inner_tile_cols];
    float16x8_t U[inner_tile_rows][inner_tile_cols];

    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vdupq_n_f16(0.0f);
        XTx[i][j] = vdupq_n_f16(0.0f);
      }
    }

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vld1q_f16(x_ptrs[i][j]);
        x_ptrs[i][j] += 8;
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      // XTx[0][j] = x[0][j] - x[2][j];
      XTx[0][j] = vsubq_f16(x[0][j], x[2][j]);

      // XTx[1][j] = x[1][j] + x[2][j];
      XTx[1][j] = vaddq_f16(x[1][j], x[2][j]);

      // XTx[2][j] = x[2][j] - x[1][j];
      XTx[2][j] = vsubq_f16(x[2][j], x[1][j]);

      // XTx[3][j] = x[1][j] - x[3][j];
      XTx[3][j] = vsubq_f16(x[1][j], x[3][j]);
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      // U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][0] = vsubq_f16(XTx[i][0], XTx[i][2]);

      // U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][1] = vaddq_f16(XTx[i][1], XTx[i][2]);

      // U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][2] = vsubq_f16(XTx[i][2], XTx[i][1]);

      // U[i][3] = XTx[i][1] - XTx[i][3];
      U[i][3] = vsubq_f16(XTx[i][1], XTx[i][3]);
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        vst1q_f16(outptr + m*matrix_stride, U[i][j]);
      }
    }
    outptr += 8;
  }
#endif // __aarch64__
#ifdef __arm_any__
  for (; channels_remaining >= 4; channels_remaining -= 4)
  {
    // Matrices used/computed in this kernel.
    float16x4_t x[inner_tile_rows][inner_tile_cols];
    float16x4_t XTx[inner_tile_rows][inner_tile_cols];
    float16x4_t U[inner_tile_rows][inner_tile_cols];

    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vdup_n_f16(0.0f);
        XTx[i][j] = vdup_n_f16(0.0f);
      }
    }

    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = vld1_f16(x_ptrs[i][j]);
        x_ptrs[i][j] += 4;
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      // XTx[0][j] = x[0][j] - x[2][j];
      XTx[0][j] = vsub_f16(x[0][j], x[2][j]);

      // XTx[1][j] = x[1][j] + x[2][j];
      XTx[1][j] = vadd_f16(x[1][j], x[2][j]);

      // XTx[2][j] = x[2][j] - x[1][j];
      XTx[2][j] = vsub_f16(x[2][j], x[1][j]);

      // XTx[3][j] = x[1][j] - x[3][j];
      XTx[3][j] = vsub_f16(x[1][j], x[3][j]);
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      // U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][0] = vsub_f16(XTx[i][0], XTx[i][2]);

      // U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][1] = vadd_f16(XTx[i][1], XTx[i][2]);

      // U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][2] = vsub_f16(XTx[i][2], XTx[i][1]);

      // U[i][3] = XTx[i][1] - XTx[i][3];
      U[i][3] = vsub_f16(XTx[i][1], XTx[i][3]);
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        vst1_f16(outptr + m*matrix_stride, U[i][j]);
      }
    }
    outptr += 4;
  }
#endif // __arm_any__
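  // Scalar tail: transform any channels left over from the vectorised loops.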
  for (; channels_remaining; channels_remaining--)
  {
    // Load x
    for (int i = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++)
      {
        x[i][j] = *(x_ptrs[i][j]++);
      }
    }

    // Compute XT . x
    for (int j = 0; j < inner_tile_cols; j++)
    {
      XTx[0][j] = x[0][j] - x[2][j];
      XTx[1][j] = x[1][j] + x[2][j];
      XTx[2][j] = x[2][j] - x[1][j];
      XTx[3][j] = x[1][j] - x[3][j];
    }

    // Compute U = XT . x . X
    for (int i = 0; i < inner_tile_rows; i++)
    {
      U[i][0] = XTx[i][0] - XTx[i][2];
      U[i][1] = XTx[i][1] + XTx[i][2];
      U[i][2] = XTx[i][2] - XTx[i][1];
      U[i][3] = XTx[i][1] - XTx[i][3];
    }

    // Store the transformed matrix
    for (int i = 0, m = 0; i < inner_tile_rows; i++)
    {
      for (int j = 0; j < inner_tile_cols; j++, m++)
      {
        *(outptr + m*matrix_stride) = U[i][j];
      }
    }
    outptr++;
  }
}

template class InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>;

} // namespace winograd
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC