blob: cfd2029f11c77ba6689b32507530d6ed649dcdf5 [file] [log] [blame]
Pablo Tello000d33a2018-09-03 16:59:20 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp"
26#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
27#include "arm_compute/core/NEON/kernels/convolution/common/arm.hpp"
28
29namespace winograd
30{
31
32using Transform = WinogradGEMM<1, 2, 1, 7>::OutputTransform<float>;
33using TransformTransposed = WinogradGEMM<2, 1, 7, 1>::OutputTransform<float>;
34
35template <>
36template <>
37int Transform::ops_performed(const Tensor4DShape &shape)
38{
39 (void) shape;
40 return 0; // TODO
41}
42
43template <>
44template <>
45template <int pad_bottom, int pad_right>
46void Transform::process_tile(
47 const int n_channels,
48 const float* const matrix_base,
49 const int matrix_stride,
50 const float* const biases,
51 float* const output,
52 const int output_row_stride,
53 const int output_col_stride
54)
55{
56 (void) output_row_stride;
57 constexpr int cells_j = output_tile_cols - pad_right;
58
59 // Construct a map to the output cells
60 float *outptrs[cells_j];
61 for (int j = 0; j < cells_j; j++)
62 {
63 outptrs[j] = output + j*output_col_stride;
64 }
65 const float *inptr = matrix_base;
66 const float *bptr = biases;
67
68 // For each channel of the output
69 int channels_remaining = n_channels;
70#ifdef __arm_any__
71 for (; channels_remaining >= 4; channels_remaining -= 4)
72 {
73 // Matrices used and computed during this transform
74 float32x4_t F[inner_tile_cols], f[output_tile_cols], b = vdupq_n_f32(0.0f);
75
76 // Read a 1x8 tile in the Winograd domain
77 for (int j = 0; j < inner_tile_cols; j++)
78 {
79 F[j] = vld1q_f32(inptr + j*matrix_stride);
80 }
81 inptr += 4;
82
83 f[0] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[6], 1), F[5], 1), F[4], 1), F[3], 1), F[2], 1), F[1], 1), F[0], 1);
84 f[1] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[7], 1), F[2], 1), F[6], 3), F[4], 2), F[3], -2), F[5], -3), F[1], -1);
85
86 // Write out the output tile
87 if (bptr != 0)
88 {
89 b = vld1q_f32(bptr);
90 bptr += 4;
91 }
92 for (int j = 0; j < cells_j; j++)
93 {
94 vst1q_f32(outptrs[j], f[j] + b);
95 outptrs[j] += 4;
96 }
97 }
98 for (; channels_remaining >= 2; channels_remaining -= 2)
99 {
100 // Matrices used and computed during this transform
101 float32x2_t F[inner_tile_cols], f[output_tile_cols], b = vdup_n_f32(0.0f);
102
103 // Read a 1x8 tile in the Winograd domain
104 for (int j = 0; j < inner_tile_cols; j++)
105 {
106 F[j] = vld1_f32(inptr + j*matrix_stride);
107 }
108 inptr += 2;
109
110 f[0] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[6], 1), F[5], 1), F[4], 1), F[3], 1), F[2], 1), F[1], 1), F[0], 1);
111 f[1] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[7], 1), F[2], 1), F[6], 3), F[4], 2), F[3], -2), F[5], -3), F[1], -1);
112
113 // Write out the output tile
114 if (bptr != 0)
115 {
116 b = vld1_f32(bptr);
117 bptr += 2;
118 }
119 for (int j = 0; j < cells_j; j++)
120 {
121 vst1_f32(outptrs[j], f[j] + b);
122 outptrs[j] += 2;
123 }
124 }
125#endif // __arm_any__
126 for (; channels_remaining; channels_remaining--)
127 {
128 // Matrices used and computed during this transform
129 float F[inner_tile_cols], f[output_tile_cols], b = 0.0f;
130
131 // Read a 1x8 tile in the Winograd domain
132 for (int j = 0; j < inner_tile_cols; j++)
133 {
134 F[j] = *(inptr + j*matrix_stride);
135 }
136 inptr++;
137
138 f[0] = F[0]*1 + F[1]*1 + F[2]*1 + F[3]*1 + F[4]*1 + F[5]*1 + F[6]*1;
139 f[1] = F[1]*-1 + F[5]*-3 + F[3]*-2 + F[4]*2 + F[6]*3 + F[2]*1 + F[7]*1;
140
141 // Write out the output tile
142 if (bptr != 0)
143 {
144 b = *(bptr++);
145 }
146 for (int j = 0; j < cells_j; j++)
147 {
148 *(outptrs[j]++) = f[j] + b;
149 }
150 }
151}
152
153template <>
154template <>
155const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] =
156{
157 {
158 Transform::template process_tile<0, 0>,
159 Transform::template process_tile<0, 1>,
160 },
161};
162
163
164template <>
165template <>
166const TransformTransposed::TileFn TransformTransposed::tile_fns[max_pad_bottom][max_pad_right] = {};
167
168template struct WinogradGEMM<1, 2, 1, 7>::OutputTransform<float>;
169template struct WinogradGEMM<2, 1, 7, 1>::OutputTransform<float>;
170} // namespace winograd