blob: 8e4bebcd208e39a82ec7dad0ba27f2edeccbccde [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#pragma once
26
Pablo Tello5264b7d2019-10-21 14:25:41 +010027#include <algorithm>
28
Pablo Tello8f43d742019-03-27 09:28:32 +000029#include "padding.hpp"
Pablo Tello5264b7d2019-10-21 14:25:41 +010030#include "utils.hpp"
31#include "winograd.hpp"
Pablo Tello8f43d742019-03-27 09:28:32 +000032
// Expands to the template header and class qualifier needed to define a member
// of the generic InputTransform template out of line. RTYPE is the member's
// return type; it is left empty when defining a constructor.
#define MEMBERFN(RTYPE) template <\
  int InnerTileRows, int InnerTileCols,\
  typename TIn, typename TOut, WinogradRoots Roots\
> RTYPE InputTransform<InnerTileRows, InnerTileCols, TIn, TOut, Roots>
37
38
// Same as MEMBERFN, but for the partial specialisation whose InnerTileCols is
// fixed to 1 (tiles that are InnerTileRows x 1). RTYPE is the member's return
// type; it is left empty when defining a constructor.
#define Nx1MEMBERFN(RTYPE) template <\
  int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots\
> RTYPE InputTransform<InnerTileRows, 1, TIn, TOut, Roots>
42
43namespace winograd
44{
45
46MEMBERFN()::InputTransform(
47 const int kernel_rows,
48 const int kernel_cols,
49 const int n_batches,
50 const int n_rows,
51 const int n_cols,
52 const int n_channels,
53 const int padding_top,
54 const int padding_left,
55 const int padding_bottom,
56 const int padding_right
57) : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
58 _inptr(nullptr), _outptr(nullptr),
59 _overlap_rows(kernel_rows - 1), _overlap_cols(kernel_cols - 1),
60 _padding_top(padding_top), _padding_left(padding_left), _padding_bottom(padding_bottom), _padding_right(padding_right),
61 _tiles_M(iceildiv(padding_top + n_rows + padding_bottom - kernel_rows + 1, InnerTileRows - kernel_rows + 1)),
62 _tiles_N(iceildiv(padding_left + n_cols + padding_right - kernel_cols + 1, InnerTileCols - kernel_cols + 1)),
63 _matrix_stride(0), _matrix_row_stride(0), _matrix_batch_stride(0),
64 _in_col_stride(0), _in_row_stride(0), _in_batch_stride(0),
65 _working_space_col_stride(n_channels),
66 _working_space_row_stride(InnerTileCols * _working_space_col_stride),
67 _working_space(nullptr)
68{
69}
70
71MEMBERFN(void)::set_input_tensor(const void* const inptr)
72{
73 set_input_tensor(inptr, _n_channels);
74}
75
76MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldcol)
77{
78 set_input_tensor(inptr, _n_cols * ldcol, ldcol);
79}
80
81MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldrow, const int ldcol)
82{
83 set_input_tensor(inptr, _n_rows * ldrow, ldrow, ldcol);
84}
85
86MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldbatch, const int ldrow, const int ldcol)
87{
88 _inptr = static_cast<const TIn *>(inptr);
89 _in_batch_stride = ldbatch;
90 _in_row_stride = ldrow;
91 _in_col_stride = ldcol;
92}
93
94MEMBERFN(void)::set_output_matrices(void * const mptr, const int ldmatrix, const int ldrow)
95{
96 _outptr = static_cast<TOut *>(mptr);
97 _matrix_stride = ldmatrix;
98 _matrix_row_stride = ldrow;
99 _matrix_batch_stride = _tiles_M * _tiles_N * ldrow;
100}
101
102Nx1MEMBERFN()::InputTransform(
103 const int kernel_rows,
104 const int kernel_cols,
105 const int n_batches,
106 const int n_rows,
107 const int n_cols,
108 const int n_channels,
109 const int padding_top,
110 const int padding_left,
111 const int padding_bottom,
112 const int padding_right
113) : InputTransform<1, InnerTileRows, TIn, TOut, Roots>::InputTransform(
114 /* Transpose rows and columns */
115 kernel_cols, kernel_rows, n_batches, n_cols, n_rows, n_channels,
116 padding_left, padding_top, padding_right, padding_bottom
117 )
118{
119}
120
121Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr)
122{
123 set_input_tensor(inptr, this->_n_channels);
124}
125
126Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldcol)
127{
128 set_input_tensor(inptr, this->_n_cols * ldcol, ldcol);
129}
130
131Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldrow, const int ldcol)
132{
133 set_input_tensor(inptr, this->_n_rows * ldrow, ldrow, ldcol);
134}
135
136Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldbatch, const int ldrow, const int ldcol)
137{
138 // Transpose row and column strides
139 Base::set_input_tensor(inptr, ldbatch, ldcol, ldrow);
140}
141
142MEMBERFN(size_t)::get_working_space_size(const unsigned int nthreads) const
143{
144 return sizeof(TIn) * InnerTileRows * _working_space_row_stride * nthreads;
145}
146
147MEMBERFN(void)::set_working_space(void * const buffer)
148{
149 _working_space = static_cast<TIn *>(buffer);
150}
151
152MEMBERFN(unsigned int)::get_window(void) const
153{
154 return iceildiv(_n_channels, WINDOW_BLOCK);
155}
156
157MEMBERFN(void)::run(
158 const unsigned int start,
159 const unsigned int stop,
160 const unsigned int threadid
161)
162{
163 // Determine the channels on which to work
164 if (start >= get_window())
165 {
166 return; // No work to do beyond the end of the window
167 }
168 const unsigned int start_channel = start * WINDOW_BLOCK;
169 const unsigned int stop_channel = std::min<unsigned int>(_n_channels , stop * WINDOW_BLOCK);
170 const unsigned int n_channels = stop_channel - start_channel;
171
172 // Loop over batches
173 for (int batch = 0; batch < _n_batches; batch++)
174 {
175 const TIn* const inptr_batch = _inptr + start_channel + batch*_in_batch_stride;
176 TOut* const outptr_batch = _outptr + start_channel + batch*_matrix_batch_stride;
177
178 // Loop over rows of tiles
179 for (int tile_i = 0; tile_i < _tiles_M; tile_i++)
180 {
181 // Compute the starting and ending row of pixels within the row of tiles,
182 // hence compute the padding to apply to the top and bottom of each tile.
183 const int row_top = tile_i * (InnerTileRows - _overlap_rows) - _padding_top;
184 const int row_bottom = row_top + InnerTileRows;
185 const int row_pad_top = std::max(0, _padding_top - tile_i * (InnerTileRows - _overlap_rows));
186 const int row_pad_bottom = std::max(0, row_bottom - _n_rows);
187
188 // Get a pointer to the start of the row.
189 const int row_offset = std::min(0, row_pad_top - _padding_top);
190 const TIn* const inptr_row = inptr_batch + _in_row_stride*(row_offset + tile_i*(InnerTileRows - _overlap_rows));
191 TOut* const outptr_row = outptr_batch + tile_i*_tiles_N*_matrix_row_stride;
192
193 // Loop over tiles within the row
194 for (int tile_j = 0; tile_j < _tiles_N; tile_j++)
195 {
196 // Compute the starting and ending column of pixels within the tile,
197 // hence compute the padding to apply to the left and right of the
198 // tile.
199 const int tile_left = tile_j * (InnerTileCols - _overlap_cols) - _padding_left;
200 const int tile_right = tile_left + InnerTileCols;
201 const int tile_pad_left = std::max(0, _padding_left - tile_j * (InnerTileCols - _overlap_cols));
202 const int tile_pad_right = std::max(0, tile_right - _n_cols);
203
204 // Get a pointer to the start of the tile.
205 const int col_offset = std::min(0, tile_pad_left - _padding_left);
206 const TIn* const inptr_tile = inptr_row + _in_col_stride*(col_offset + tile_j*(InnerTileCols - _overlap_cols));
207 TOut* const outptr_tile = outptr_row + tile_j * _matrix_row_stride;
208
209 // Transform the tile, applying padding if necessary.
210 if (row_pad_top || tile_pad_left || row_pad_bottom || tile_pad_right)
211 {
212 transform_padded_tile(
213 threadid, n_channels, outptr_tile, inptr_tile,
214 row_pad_top, tile_pad_left, row_pad_bottom, tile_pad_right
215 );
216 }
217 else
218 {
219 transform_unpadded_tile(threadid, n_channels, outptr_tile, inptr_tile);
220 }
221 }
222 }
223 }
224}
225
226MEMBERFN(void)::transform_unpadded_tile(
227 const unsigned int /* threadid unused */,
228 const int n_channels,
229 TOut * const outptr,
230 const TIn * const inptr
231)
232{
233 transform_tile(
234 n_channels, inptr, _in_row_stride, _in_col_stride, outptr, _matrix_stride
235 );
236}
237
238MEMBERFN(void)::transform_padded_tile(
239 const unsigned int threadid,
240 const int n_channels,
241 TOut * const outptr,
242 const TIn * const inptr,
243 const int padding_top,
244 const int padding_left,
245 const int padding_bottom,
246 const int padding_right
247)
248{
249 padding::copy_and_pad_tile(
250 InnerTileRows, InnerTileCols, n_channels,
251 inptr, _in_row_stride, _in_col_stride,
252 static_cast<TIn *>(get_working_space(threadid)), _working_space_row_stride, _working_space_col_stride,
253 padding_top, padding_left, padding_bottom, padding_right
254 );
255
256 transform_tile(
257 n_channels, static_cast<const TIn *>(get_working_space(threadid)),
258 _working_space_row_stride, _working_space_col_stride,
259 outptr, _matrix_stride
260 );
261}
262
263MEMBERFN(void *)::get_working_space(const unsigned int threadid) const
264{
265 return _working_space + InnerTileRows * _working_space_row_stride * threadid;
266}
267
268} // namespace winograd