blob: 033442aa148069260c5972fbc88a2a445a0d6c34 [file] [log] [blame]
Pablo Tello89519332017-11-17 11:52:36 +00001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#pragma once
25
26namespace winograd {
27 /* Transform a kernel into the Winograd domain.
28 *
29 * NOTE: It is assumed that the kernel is in the form [height x width x
30 * input_channels x output_channel].
31 */
32 template <typename T>
33 struct winograd2x2_3x3_gemm_kernel_transform_impl{
34 static void execute(
35 const KernelShape &shape,
36 const T* const kernel,
37 T* const matrix_base,
38 const int matrix_stride,
39 const int matrix_row_stride
40 );
41
42 protected:
43 template <const int output_channel_tail>
44 static void transform_kernel(
45 const T* const kernel,
46 const int n_input_channels,
47 const int n_output_channels,
48 T* const matrix_base,
49 const int matrix_stride,
50 const int matrix_row_stride
51 );
52 };
53}
54
55/*****************************************************************************/
56/* Transform a fp32 kernel into the Winograd domain.
57 */
58#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations
59
60namespace winograd
61{
62template <>
63inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
64 const KernelShape &shape,
65 const float* const kernel,
66 float* const matrix_base,
67 const int matrix_stride,
68 const int matrix_row_stride
69) {
70 // Delegate based on tail size
71 const int n_input_channels = shape.n_input_channels;
72 const int n_output_channels = shape.n_output_channels;
73
74 switch (n_output_channels % 4) {
75 case 0:
76 transform_kernel<0>(
77 kernel, n_input_channels, n_output_channels,
78 matrix_base, matrix_stride, matrix_row_stride
79 );
80 break;
81 case 1:
82 transform_kernel<1>(
83 kernel, n_input_channels, n_output_channels,
84 matrix_base, matrix_stride, matrix_row_stride
85 );
86 break;
87 case 2:
88 transform_kernel<2>(
89 kernel, n_input_channels, n_output_channels,
90 matrix_base, matrix_stride, matrix_row_stride
91 );
92 break;
93 case 3:
94 transform_kernel<3>(
95 kernel, n_input_channels, n_output_channels,
96 matrix_base, matrix_stride, matrix_row_stride
97 );
98 break;
99 default:
100 ARM_COMPUTE_ERROR("Cannot happen");
101 break;
102 }
103}
104
105template <>
106template<const int output_channel_tail>
107inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
108 const float* const kernel,
109 const int n_input_channels,
110 const int n_output_channels,
111 float* const matrix_base,
112 const int mstride,
113 const int matrix_row_stride
114) {
115 // Use one input pointer for each row of the kernel, use two additional
116 // offsets to extract columns.
117 const int kernel_col_stride = n_input_channels * n_output_channels;
118 const int kernel_row_stride = 3 * kernel_col_stride;
119 const float *inptr0 = kernel;
120 const float *inptr1 = kernel + kernel_row_stride;
121 const float *inptr2 = kernel + kernel_row_stride*2;
122
123 // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
124 // offsets to extract further matrices.
125 float *outptr0 = matrix_base;
126 float *outptr4 = matrix_base + mstride * 4;
127 float *outptr8 = matrix_base + mstride * 8;
128 float *outptr12 = matrix_base + mstride * 12;
129
130 // For every input channel
131 for (int in_c = 0; in_c < n_input_channels; in_c++) {
132 // For every output channel
133 for (int c = 0; c < n_output_channels; c++) {
134 // Read in the kernel
135 float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
136 float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
137 float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
138
139 // Progress input pointers
140 inptr0++;
141 inptr1++;
142 inptr2++;
143
144 // Compute the kernel W w, note we need only compute the middle two rows
145 // (2 and 3) because the first and last rows are merely copies of values
146 // from the matrix w.
147 float Ww11 = w11, Ww12 = w12, Ww13 = w13;
148 float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33);
149 float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33);
150 float Ww41 = w31, Ww42 = w32, Ww43 = w33;
151
152 // Hence compute W w W.T; again note we need compute only the middle two
153 // columns since the first and last columns are copies of the first and
154 // last columns of the previous matrix.
155 float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13;
156 float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23;
157 float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33;
158 float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43;
159
160 // Store the computed weights
161 outptr0[0 * mstride] = WwWT11;
162 outptr0[1 * mstride] = WwWT12;
163 outptr0[2 * mstride] = WwWT13;
164 outptr0[3 * mstride] = WwWT14;
165
166 outptr4[0 * mstride] = WwWT21;
167 outptr4[1 * mstride] = WwWT22;
168 outptr4[2 * mstride] = WwWT23;
169 outptr4[3 * mstride] = WwWT24;
170
171 outptr8[0 * mstride] = WwWT31;
172 outptr8[1 * mstride] = WwWT32;
173 outptr8[2 * mstride] = WwWT33;
174 outptr8[3 * mstride] = WwWT34;
175
176 outptr12[0 * mstride] = WwWT41;
177 outptr12[1 * mstride] = WwWT42;
178 outptr12[2 * mstride] = WwWT43;
179 outptr12[3 * mstride] = WwWT44;
180
181 // Progress output pointers
182 outptr0++;
183 outptr4++;
184 outptr8++;
185 outptr12++;
186 }
187
188 // Progression to complete stride
189 outptr0 += matrix_row_stride - n_output_channels;
190 outptr4 += matrix_row_stride - n_output_channels;
191 outptr8 += matrix_row_stride - n_output_channels;
192 outptr12 += matrix_row_stride - n_output_channels;
193 }
194}
195}