/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
/* arm_conv kernels share a lot of similarities in how they address input and
 * output tensors. Consequently, this file contains common approaches to
 * preparing these tensor descriptions. Generic (i.e., untyped) methods are
 * contained within the `arm_conv::addressing` namespace, and typed wrappers
 * are provided within an anonymous namespace within `arm_conv`. The various
 * methods are described below.
 */
32
#include <cstddef>
#include <type_traits>
34
namespace arm_conv {
namespace addressing {

/* Pointer array
 * -------------
 *
 * Constructs an array of pointers which point to a `array_rows` x `array_cols`
 * chunk of a tensor. The array of pointers will be written into `dest`.
 *
 * `base_ptr` should point at the first VALID element of the chunk of tensor
 * (i.e., if there's one padded row, and one padded column, then `base_ptr`
 * should point at the element which will be at position (1, 1) in the array).
 * `ld_row` and `ld_col` are in bytes, and describe the strides over rows and
 * columns (respectively) of the NHWC-ordered tensor. `pad_buffer` should point
 * at a suitably sized (and initialised) area of memory which can be addressed
 * by elements of the array which represent padding.
 *
 * `pad_top` and `pad_left` describe the padding on the top and left of the
 * array, respectively, and `valid_rows` and `valid_cols` describe the number
 * of rows and columns between the element pointed to by `base_ptr` and the
 * edge of the image (that is `valid_rows` may be greater than `array_rows` and
 * likewise for the columns).
 */
void fill_pointer_array(
  size_t element_size,
  void **dest, unsigned int array_rows, unsigned int array_cols,
  void *base_ptr, size_t ld_row, size_t ld_col,
  void *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);

/* Interleaved multi-point pointer array
 * -------------------------------------
 *
 * For each point in a `output_rows` x `output_cols` array, constructs
 * `kernel_rows` x `kernel_cols` array of pointers. The pointers are
 * interleaved thusly:
 *
 *   for ki in kernel_rows:
 *     for kj in kernel_cols:
 *       for oi in output_rows:
 *         for oj in output_cols:
 *           get pointer for point (oi*stride_rows + ki, oj*stride_cols + kj)
 *
 * Other arguments are as for `fill_pointer_array`.
 *
 * The name reflects that this is the form of addressing mode used by "generic"
 * depthwise and pooling kernels.
 */
void fill_pointer_array_generic_kernel(
  size_t element_size,
  void **dest,
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  void *base_ptr, size_t ld_row, size_t ld_col,
  void *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);

/* NCHW-patch addressed by row
 * ---------------------------
 *
 * Construct an array of pointers, each of which points at a row of an
 * NCHW-ordered patch of a tensor. Memory addressed by the pointers may be
 * outside of the original tensor, and should therefore not be written to
 * (modifications will be lost).
 *
 * `dest_row_pointers` should point at a `patch_rows` list of pointers; each of
 * which will point at a 1 x `patch_cols` NCHW-ordered sample of the source
 * tensor.
 *
 * `dest_patch` should point to an `element_size * patch_rows * patch_cols`
 * area of memory which can be written to by this function to form samples of
 * the source tensor.
 *
 * `src_ptr` should point at the first VALID element of the chunk of tensor
 * (i.e., if there's one padded row, and one padded column, then `src_ptr`
 * should point at the element which will be at position (1, 1) in the array).
 * `ld_row` and `ld_col` are in bytes, and describe the strides over rows and
 * columns (respectively) of the NHWC-ordered tensor. If `ld_col` ==
 * `element_size` then copies from the source tensor will be elided and source
 * data may be addressed directly.
 *
 * `pad_row` should point to a `patch_cols` array of (appropriately
 * initialised) padding values.
 *
 * Other arguments are as for `fill_pointer_array`.
 */
void fill_nchw_patch_array(
  size_t element_size,
  const void **dest_row_pointers,  // Array of pointers to each row of the patch
  void *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int patch_rows, unsigned int patch_cols,  // Patch size
  const void *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const void *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);

/* Patch array for "generic" kernels
 * ---------------------------------
 *
 * NOTE(review): this declaration carried no description. Judging by its name
 * and signature it appears to combine the patch construction of
 * `fill_nchw_patch_array` with the per-kernel-point layout of
 * `fill_pointer_array_generic_kernel` - confirm against the implementation.
 *
 * `dest_pointers` receives one pointer per output row per kernel point, and
 * `dest_patch` is scratch space in which the patch may be constructed (the
 * required sizes of both are not stated here - TODO confirm). `pad_row`
 * points to a row of padding values.
 *
 * The remaining arguments presumably carry the same meaning as the
 * identically named arguments of `fill_nchw_patch_array` and
 * `fill_pointer_array_generic_kernel`.
 */
void fill_patch_array_generic_kernel(
  size_t element_size,
  const void **dest_pointers,  // Pointers: one per output row per kernel point
  void *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  const void *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const void *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);

}  // namespace addressing

namespace {

/* Pointer array
 * -------------
 *
 * See `addressing::fill_pointer_array`. No copies are made by this method,
 * memory pointed to by the pointer array is contained within the base tensor
 * and the padding buffer.
 *
 * This wrapper supplies `sizeof(T)` as the element size and forwards every
 * other argument unchanged.
 *
 * NOTE(review): `ld_row` and `ld_col` are forwarded as given; the untyped
 * function documents them as byte strides - confirm that callers of this
 * typed wrapper pass strides in the unit the implementation expects.
 */
template <typename T>
inline void fill_pointer_array(
  T **dest, unsigned int array_rows, unsigned int array_cols,
  T *base_ptr, size_t ld_row, size_t ld_col,
  T *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  // The untyped interface takes non-const pointers, so any cv-qualification
  // carried by `T` has to be dropped before the call. Doing so with named
  // casts keeps that step visible (a C-style cast would do it silently) while
  // still accepting cv-qualified instantiations of `T`.
  using unqualified_t = typename std::remove_cv<T>::type;

  addressing::fill_pointer_array(
    sizeof(T),
    reinterpret_cast<void **>(const_cast<unqualified_t **>(dest)),
    array_rows, array_cols,
    const_cast<unqualified_t *>(base_ptr), ld_row, ld_col,
    const_cast<unqualified_t *>(pad_buffer),
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}


/* Interleaved multi-point pointer array
 * -------------------------------------
 *
 * See `addressing::fill_pointer_array_generic_kernel`. No copies are made by
 * this method, memory pointed to by the pointer array is contained within the
 * base tensor and the padding buffer.
 *
 * This wrapper supplies `sizeof(T)` as the element size and forwards every
 * other argument unchanged.
 */
template <typename T>
inline void fill_pointer_array_generic_kernel(
  T **dest,
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  T *base_ptr, size_t ld_row, size_t ld_col,
  T *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  // As in `fill_pointer_array`: strip any cv-qualification of `T` explicitly
  // because the untyped interface is expressed in terms of plain `void *`.
  using unqualified_t = typename std::remove_cv<T>::type;

  addressing::fill_pointer_array_generic_kernel(
    sizeof(T),
    reinterpret_cast<void **>(const_cast<unqualified_t **>(dest)),
    output_rows, output_cols,
    kernel_rows, kernel_cols,
    stride_rows, stride_cols,
    const_cast<unqualified_t *>(base_ptr), ld_row, ld_col,
    const_cast<unqualified_t *>(pad_buffer),
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}

/* NCHW-patch addressed by row
 * ---------------------------
 *
 * See `addressing::fill_nchw_patch_array`. This wrapper supplies `sizeof(T)`
 * as the element size and forwards every other argument unchanged.
 */
template <typename T>
inline void fill_nchw_patch_array(
  const T **dest_row_pointers,  // Array of pointers to each row of the patch
  T *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int patch_rows, unsigned int patch_cols,  // Patch size
  const T *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const T *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  // Only the pointer-to-pointer argument needs an explicit cast: `const T **`
  // does not convert implicitly to `const void **`. The single-level pointers
  // convert to (const) `void *` on their own.
  addressing::fill_nchw_patch_array(
    sizeof(T),
    reinterpret_cast<const void **>(dest_row_pointers),
    dest_patch,
    patch_rows, patch_cols,
    src_ptr, ld_row, ld_col,
    pad_row,
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}

/* Patch array for "generic" kernels
 * ---------------------------------
 *
 * See `addressing::fill_patch_array_generic_kernel`. This wrapper supplies
 * `sizeof(T)` as the element size and forwards every other argument
 * unchanged.
 */
template <typename T>
inline void fill_patch_array_generic_kernel(
  const T **dest_pointers,  // Pointers: one per output row per kernel point
  T *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  const T *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const T *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  // Only the pointer-to-pointer argument needs an explicit cast; see
  // `fill_nchw_patch_array`.
  addressing::fill_patch_array_generic_kernel(
    sizeof(T),
    reinterpret_cast<const void **>(dest_pointers),
    dest_patch,
    output_rows, output_cols,
    kernel_rows, kernel_cols,
    stride_rows, stride_cols,
    src_ptr, ld_row, ld_col,
    pad_row,
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}

}  // namespace {anonymous}
}  // namespace arm_conv