blob: 9393785dfc2b72381864a079780a9bd34f1b1373 [file] [log] [blame]
Pablo Tello8f43d742019-03-27 09:28:32 +00001/*
2 * Copyright (c) 2019 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "input.hpp"
26#include "arm.hpp"
27
28namespace winograd
29{
30
31template <>
32void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile(
33 const int n_channels,
34 const float* const input_base,
35 const int input_row_stride,
36 const int input_col_stride,
37 float* outptr,
38 const int matrix_stride
39)
40{
41 constexpr int inner_tile_rows = 4, inner_tile_cols = 4;
42
43 // Get pointers into the input tile
44 const float *x_ptrs[inner_tile_rows][inner_tile_cols];
45 for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
46 {
47 // Get a pointer into the row
48 const float* const row_ptr = input_base + xi*input_row_stride;
49
50 for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
51 {
52 x_ptrs[i][j] = row_ptr + xj*input_col_stride;
53 }
54 }
55
56 // Matrices used/computed in this kernel.
57 float x[inner_tile_rows][inner_tile_cols];
58 float XTx[inner_tile_rows][inner_tile_cols];
59 float U[inner_tile_rows][inner_tile_cols];
60
61 for (int i = 0; i < inner_tile_rows; i++)
62 {
63 for (int j = 0; j < inner_tile_cols; j++)
64 {
65 x[i][j] = XTx[i][j] = 0.0f;
66 }
67 }
68
69 // Perform the Winograd input transformation for each channel in the input
70 // tensor.
71 int channels_remaining = n_channels;
72#ifdef __aarch64__
73 for (; channels_remaining >= 4; channels_remaining -= 4)
74 {
75 // Matrices used/computed in this kernel.
76 float32x4_t x[inner_tile_rows][inner_tile_cols];
77 float32x4_t XTx[inner_tile_rows][inner_tile_cols];
78 float32x4_t U[inner_tile_rows][inner_tile_cols];
79
80 for (int i = 0; i < inner_tile_rows; i++)
81 {
82 for (int j = 0; j < inner_tile_cols; j++)
83 {
84 x[i][j] = vdupq_n_f32(0.0f);
85 XTx[i][j] = vdupq_n_f32(0.0f);
86 }
87 }
88
89 // Load x
90 for (int i = 0; i < inner_tile_rows; i++)
91 {
92 for (int j = 0; j < inner_tile_cols; j++)
93 {
94 x[i][j] = vld1q_f32(x_ptrs[i][j]);
95 x_ptrs[i][j] += 4;
96 }
97 }
98
99 // Compute XT . x
100 for (int j = 0; j < inner_tile_cols; j++)
101 {
102 // XTx[0][j] = x[0][j] - x[2][j];
103 XTx[0][j] = vsubq_f32(x[0][j], x[2][j]);
104
105 // XTx[1][j] = x[1][j] + x[2][j];
106 XTx[1][j] = vaddq_f32(x[1][j], x[2][j]);
107
108 // XTx[2][j] = x[2][j] - x[1][j];
109 XTx[2][j] = vsubq_f32(x[2][j], x[1][j]);
110
111 // XTx[3][j] = x[1][j] - x[3][j];
112 XTx[3][j] = vsubq_f32(x[1][j], x[3][j]);
113 }
114
115 // Compute U = XT . x . X
116 for (int i = 0; i < inner_tile_rows; i++)
117 {
118 // U[i][0] = XTx[i][0] - XTx[i][2];
119 U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]);
120
121 // U[i][1] = XTx[i][1] + XTx[i][2];
122 U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]);
123
124 // U[i][2] = XTx[i][2] - XTx[i][1];
125 U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]);
126
127 // U[i][3] = XTx[i][1] - XTx[i][3];
128 U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]);
129 }
130
131 // Store the transformed matrix
132 for (int i = 0, m = 0; i < inner_tile_rows; i++)
133 {
134 for (int j = 0; j < inner_tile_cols; j++, m++)
135 {
136 vst1q_f32(outptr + m*matrix_stride, U[i][j]);
137 }
138 }
139 outptr += 4;
140 }
141#endif // __aarch64__
142#ifdef __arm_any__
143 for (; channels_remaining >= 2; channels_remaining -= 2)
144 {
145 // Matrices used/computed in this kernel.
146 float32x2_t x[inner_tile_rows][inner_tile_cols];
147 float32x2_t XTx[inner_tile_rows][inner_tile_cols];
148 float32x2_t U[inner_tile_rows][inner_tile_cols];
149
150 for (int i = 0; i < inner_tile_rows; i++)
151 {
152 for (int j = 0; j < inner_tile_cols; j++)
153 {
154 x[i][j] = vdup_n_f32(0.0f);
155 XTx[i][j] = vdup_n_f32(0.0f);
156 }
157 }
158
159 // Load x
160 for (int i = 0; i < inner_tile_rows; i++)
161 {
162 for (int j = 0; j < inner_tile_cols; j++)
163 {
164 x[i][j] = vld1_f32(x_ptrs[i][j]);
165 x_ptrs[i][j] += 2;
166 }
167 }
168
169 // Compute XT . x
170 for (int j = 0; j < inner_tile_cols; j++)
171 {
172 // XTx[0][j] = x[0][j] - x[2][j];
173 XTx[0][j] = vsub_f32(x[0][j], x[2][j]);
174
175 // XTx[1][j] = x[1][j] + x[2][j];
176 XTx[1][j] = vadd_f32(x[1][j], x[2][j]);
177
178 // XTx[2][j] = x[2][j] - x[1][j];
179 XTx[2][j] = vsub_f32(x[2][j], x[1][j]);
180
181 // XTx[3][j] = x[1][j] - x[3][j];
182 XTx[3][j] = vsub_f32(x[1][j], x[3][j]);
183 }
184
185 // Compute U = XT . x . X
186 for (int i = 0; i < inner_tile_rows; i++)
187 {
188 // U[i][0] = XTx[i][0] - XTx[i][2];
189 U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]);
190
191 // U[i][1] = XTx[i][1] + XTx[i][2];
192 U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]);
193
194 // U[i][2] = XTx[i][2] - XTx[i][1];
195 U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]);
196
197 // U[i][3] = XTx[i][1] - XTx[i][3];
198 U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]);
199 }
200
201 // Store the transformed matrix
202 for (int i = 0, m = 0; i < inner_tile_rows; i++)
203 {
204 for (int j = 0; j < inner_tile_cols; j++, m++)
205 {
206 vst1_f32(outptr + m*matrix_stride, U[i][j]);
207 }
208 }
209 outptr += 2;
210 }
211#endif // __arm_any__
212 for (; channels_remaining; channels_remaining--)
213 {
214 // Load x
215 for (int i = 0; i < inner_tile_rows; i++)
216 {
217 for (int j = 0; j < inner_tile_cols; j++)
218 {
219 x[i][j] = *(x_ptrs[i][j]++);
220 }
221 }
222
223 // Compute XT . x
224 for (int j = 0; j < inner_tile_cols; j++)
225 {
226 XTx[0][j] = x[0][j] - x[2][j];
227 XTx[1][j] = x[1][j] + x[2][j];
228 XTx[2][j] = x[2][j] - x[1][j];
229 XTx[3][j] = x[1][j] - x[3][j];
230 }
231
232 // Compute U = XT . x . X
233 for (int i = 0; i < inner_tile_rows; i++)
234 {
235 U[i][0] = XTx[i][0] - XTx[i][2];
236 U[i][1] = XTx[i][1] + XTx[i][2];
237 U[i][2] = XTx[i][2] - XTx[i][1];
238 U[i][3] = XTx[i][1] - XTx[i][3];
239 }
240
241 // Store the transformed matrix
242 for (int i = 0, m = 0; i < inner_tile_rows; i++)
243 {
244 for (int j = 0; j < inner_tile_cols; j++, m++)
245 {
246 *(outptr + m*matrix_stride) = U[i][j];
247 }
248 }
249 outptr++;
250 }
251}
252
253template class InputTransform<4, 4, float, float, WinogradRoots::Integers>;
254
255} // namespace