blob: 8fab6db1baeec97faba066c6a9afc37185940829 [file] [log] [blame]
/*
 * Copyright (c) 2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#include "arm.hpp"
26#include "kernel.hpp"
27
28namespace winograd
29{
30
31template <>
32void WeightTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::execute(
33 const int n_output_channels,
34 const int n_input_channels,
35 const float* const input,
36 float* const output,
37 const int matrix_stride,
38 const int matrix_row_stride
39)
40{
41 constexpr int inner_tile_i = 4;
42 constexpr int inner_tile_j = 4;
43
44 // Get pointers to each cell of the weight tensor
45 const auto weight_col_stride = n_input_channels * n_output_channels;
46 const auto weight_row_stride = 3 * weight_col_stride;
47 const float *inptrs[3][3];
48 for (int i = 0; i < 3; i++)
49 {
50 for (int j = 0; j < 3; j++)
51 {
52 inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride;
53 }
54 }
55
56 // For each input channel
57 for (int ic = 0; ic < n_input_channels; ic++)
58 {
59 float *outptr = output + ic * matrix_row_stride;
60
61 // For each output channel
62 int channels_remaining = n_output_channels;
63#ifdef __aarch64__
64 for (; channels_remaining >= 4; channels_remaining -= 4)
65 {
66 // Matrices used and computed in this kernel
67 float32x4_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
68
69 // Read weights
70 for (int i = 0; i < 3; i++)
71 {
72 for (int j = 0; j < 3; j++)
73 {
74 w[i][j] = vld1q_f32(inptrs[i][j]);
75 inptrs[i][j] += 4;
76 }
77 }
78
79 // Compute the matrix W w
80 for (int j = 0; j < 3; j++)
81 {
82 Ww[0][j] = w[0][j];
83
84 // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
85 Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
86
87 // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
88 Ww[2][j] = vmulq_n_f32(vaddq_f32(vsubq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
89
90 Ww[3][j] = w[2][j];
91 }
92
93 // Compute V = W w WT
94 for (int i = 0; i < inner_tile_i; i++)
95 {
96 V[i][0] = Ww[i][0];
97
98 // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
99 V[i][1] = vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
100
101 // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
102 V[i][2] = vmulq_n_f32(vaddq_f32(vsubq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
103
104 V[i][3] = Ww[i][2];
105 }
106
107 // Store the transformed weights
108 for (int i = 0, m = 0; i < inner_tile_i; i++)
109 {
110 for (int j = 0; j < inner_tile_j; j++, m++)
111 {
112 vst1q_f32(outptr + m*matrix_stride, V[i][j]);
113 }
114 }
115 outptr += 4;
116 }
117#endif // __aarch64__
118#ifdef __arm_any__
119 for (; channels_remaining >= 2; channels_remaining -= 2)
120 {
121 // Matrices used and computed in this kernel
122 float32x2_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
123
124 // Read weights
125 for (int i = 0; i < 3; i++)
126 {
127 for (int j = 0; j < 3; j++)
128 {
129 w[i][j] = vld1_f32(inptrs[i][j]);
130 inptrs[i][j] += 2;
131 }
132 }
133
134 // Compute the matrix W w
135 for (int j = 0; j < 3; j++)
136 {
137 Ww[0][j] = w[0][j];
138
139 // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
140 Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
141
142 // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
143 Ww[2][j] = vmul_n_f32(vadd_f32(vsub_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
144
145 Ww[3][j] = w[2][j];
146 }
147
148 // Compute V = W w WT
149 for (int i = 0; i < inner_tile_i; i++)
150 {
151 V[i][0] = Ww[i][0];
152
153 // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
154 V[i][1] = vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
155
156 // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
157 V[i][2] = vmul_n_f32(vadd_f32(vsub_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
158
159 V[i][3] = Ww[i][2];
160 }
161
162 // Store the transformed weights
163 for (int i = 0, m = 0; i < inner_tile_i; i++)
164 {
165 for (int j = 0; j < inner_tile_j; j++, m++)
166 {
167 vst1_f32(outptr + m*matrix_stride, V[i][j]);
168 }
169 }
170 outptr += 2;
171 }
172#endif // __arm_any__
173 for (; channels_remaining; channels_remaining--)
174 {
175 // Matrices used and computed in this kernel
176 float w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
177
178 // Read weights
179 for (int i = 0; i < 3; i++)
180 {
181 for (int j = 0; j < 3; j++)
182 {
183 w[i][j] = *(inptrs[i][j]++);
184 }
185 }
186
187 // Compute the matrix W w
188 for (int j = 0; j < 3; j++)
189 {
190 Ww[0][j] = w[0][j];
191 Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
192 Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
193 Ww[3][j] = w[2][j];
194 }
195
196 // Compute V = W w WT
197 for (int i = 0; i < inner_tile_i; i++)
198 {
199 V[i][0] = Ww[i][0];
200 V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
201 V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
202 V[i][3] = Ww[i][2];
203 }
204
205 // Store the transformed weights
206 for (int i = 0, m = 0; i < inner_tile_i; i++)
207 {
208 for (int j = 0; j < inner_tile_j; j++, m++)
209 {
210 *(outptr + m*matrix_stride) = V[i][j];
211 }
212 }
213 outptr++;
214 }
215 }
216}
217
218template class WeightTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>;
219
220} // namespace