blob: bc0d9d42962534134796c53e0d04cd4137f6d72f [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#pragma once
26
Pablo Tello5264b7d2019-10-21 14:25:41 +010027#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
28
29#include <cstddef>
30#include <utility>
Pablo Tello8f43d742019-03-27 09:28:32 +000031
32namespace winograd
33{
34
/**
 * Common interface of the Winograd input, output and weight transforms.
 *
 * A transform exposes a window of work (get_window()) which callers execute,
 * in one or more pieces, via run(). Any scratch memory the transform needs is
 * supplied by the caller through set_working_space().
 */
class ITransform
{
  public:
    virtual ~ITransform() = default;

    /**
     * Get the working space required to perform the transformation.
     *
     * Note, the working space is only required when performing the
     * transformation - hence it can be reused whenever the transformation is
     * not running.
     *
     * @param nthreads The greatest number of threads that will be used to execute the transform.
     * @return Size of working space required in bytes.
     */
    virtual size_t get_working_space_size(unsigned int nthreads=1) const = 0;

    /**
     * Set the working space to be used by the transformation.
     *
     * Note, the working space is only required when performing the
     * transformation - hence it can be reused whenever the transformation is
     * not running.
     *
     * @param buffer Pointer to the working space. It should provide at least the
     *               number of bytes reported by get_working_space_size() for the
     *               thread count that will be used.
     */
    virtual void set_working_space(void *buffer) = 0;

    /**
     * Get the window of work a given operator can perform.
     *
     * @return Extent of the work window that run() operates over.
     */
    virtual unsigned int get_window() const = 0;

    /**
     * Perform work upon a window of the transform.
     *
     * @param start    Start of the portion of the window to process.
     * @param stop     End of the portion of the window to process.
     * @param threadid Index of the calling thread; presumably used to select that
     *                 thread's part of the working space (TODO confirm in implementations).
     */
    virtual void run(unsigned int start, unsigned int stop, unsigned int threadid=0) = 0;
};
73
74class IInputTransform : public ITransform
75{
76 public:
77 virtual ~IInputTransform() = default;
78
79 /**
80 * Set the pointer to the (NHWC-ordered) tensor to be transformed.
81 */
82 virtual void set_input_tensor(const void *input) = 0;
83
84 /**
85 * Set the pointer to the (NHWC-ordered) tensor to be transformed.
86 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
87 */
88 virtual void set_input_tensor(const void *input, int col_stride) = 0;
89
90 /**
91 * Set the pointer to the (NHWC-ordered) tensor to be transformed.
92 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
93 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
94 */
95 virtual void set_input_tensor(const void *input, int row_stride, int col_stride) = 0;
96
97 /**
98 * Set the pointer to the (NHWC-ordered) tensor to be transformed.
99 * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes).
100 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
101 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
102 */
103 virtual void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) = 0;
104
105 /**
106 * Set pointers to the matrices written by the transform.
107 * @param matrices Pointer to the start of the first matrix representing the transformed input.
108 * @param inter_matrix_stride Stride (in elements) between matrices.
109 * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
110 */
111 virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
112};
113
114class IOutputTransform : public ITransform
115{
116 public:
117 virtual ~IOutputTransform() = default;
118
119 /**
120 * Set pointers to the matrices written by the transform.
121 * @param matrices Pointer to the start of the first matrix representing the input to the transform.
122 * @param inter_matrix_stride Stride (in elements) between matrices.
123 * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
124 */
125 virtual void set_input_matrices(const void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
126
127 /**
128 * Set pointer to the bias tensor (can be ignored or called with nullptr for no bias.
129 */
130 virtual void set_bias(const void *bias=nullptr) = 0;
131
132 /**
133 * Set pointer to the output tensor produced by the transform.
134 */
135 virtual void set_output_tensor(void *output) = 0;
136
137 /**
138 * Set pointer to the output tensor produced by the transform.
139 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
140 */
141 virtual void set_output_tensor(void *output, int col_stride) = 0;
142
143 /**
144 * Set pointer to the output tensor produced by the transform.
145 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
146 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
147 */
148 virtual void set_output_tensor(void *output, int row_stride, int col_stride) = 0;
149
150 /**
151 * Set pointer to the output tensor produced by the transform.
152 * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes).
153 * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
154 * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
155 */
156 virtual void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) = 0;
157};
158
159class IWeightTransform : public ITransform
160{
161 public:
162 virtual ~IWeightTransform() = default;
163
164 /** Set pointer to the weight tensor read by the transform. */
165 virtual void set_weight_tensor(const void *weights) = 0;
166
167 /**
168 * Set pointers to the matrices written by the transform.
169 * @param matrices Pointer to the start of the first matrix representing the transformed input.
170 * @param inter_matrix_stride Stride (in elements) between matrices.
171 * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
172 */
173 virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
174};
175
/** Selects which set of Winograd roots the transforms are built from.
 *  Only WinogradRoots::Integers is currently defined. */
enum class WinogradRoots
{
    Integers
};
180
181template <int InnerTileRows, int InnerTileCols, typename TIn, typename TOut, WinogradRoots Roots>
182class InputTransform : public IInputTransform
183{
184 public:
185 /** Create an InputTransform operator fixed on a given problem and set of
186 * pointers.
187 */
188 InputTransform(
189 int kernel_rows, /**< Number of rows in the kernel */
190 int kernel_cols, /**< Number of columns in the kernel */
191 int n_batches, /**< Number of batches in input tensor. */
192 int n_rows, /**< Number of rows in input tensor. */
193 int n_cols, /**< Number of columns in input tensor. */
194 int n_channels, /**< Number of channels in input tensor. */
195 int padding_top, /**< Padding to apply to the top of the image. */
196 int padding_left, /**< Padding to apply to the left of the image. */
197 int padding_bottom, /**< Padding to apply to the bottom of the image. */
198 int padding_right /**< Padding to apply to the right of the image. */
199 );
200
201 InputTransform(InputTransform&) = delete;
202 InputTransform operator=(InputTransform&) = delete;
203
204 /** Set pointers to the input tensor read by the transform. */
205 void set_input_tensor(const void *input) override;
206 void set_input_tensor(const void *input, int col_stride) override;
207 void set_input_tensor(const void *input, int row_stride, int col_stride) override;
208 void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override;
209
210 /** Set pointers to the matrices written by the transform. */
211 void set_output_matrices(void *matrices, int iter_matrix_stride, int matrix_row_stride) override;
212
213 /** Get the working space required to perform the transformation. */
214 size_t get_working_space_size(unsigned int nthreads=1) const override;
215 void set_working_space(void *buffer) override;
216
217 /** Get the window of work a given operator can perform. */
218 unsigned int get_window() const override;
219 static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window
220
221 /** Perform work upon a window of the input. */
222 void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
223
224 protected:
225 const int _n_batches, _n_rows, _n_cols, _n_channels;
226
227 private:
228 void transform_unpadded_tile(
229 unsigned int threadid,
230 int n_channels,
231 TOut *outptr,
232 const TIn *inptr
233 );
234
235 void transform_padded_tile(
236 unsigned int threadid,
237 int n_channels,
238 TOut *outptr,
239 const TIn *inptr,
240 int padding_top,
241 int padding_left,
242 int padding_bottom,
243 int padding_right
244 );
245
246 /* Tile implementation */
247 static void transform_tile(
248 int n_channels, /** @param[in] Number of channels in the tensor. */
249 const TIn* inptr_base, /** @param[in] Pointer to the base of the input tile. */
250 int input_row_stride, /** @param[in] Stride between rows of the input tensor. */
251 int input_col_stride, /** @param[in] Stride between columns of the input tensor. */
252 TOut* mptr_base, /** @param[out] Base pointer to transformed input matrices. */
253 int matrix_stride /** @param[in] Stride between matrices in the input space. */
254 );
255
256 /** Get the working space for a thread. */
257 void * get_working_space(unsigned int threadid) const;
258
259 const TIn* _inptr;
260 TOut* _outptr;
261
262 const int _overlap_rows, _overlap_cols;
263 const int _padding_top, _padding_left, _padding_bottom, _padding_right;
264 const int _tiles_M, _tiles_N;
265 int _matrix_stride, _matrix_row_stride, _matrix_batch_stride;
266 int _in_col_stride, _in_row_stride, _in_batch_stride;
267
268 const int _working_space_col_stride, _working_space_row_stride;
269 TIn *_working_space;
270};
271
272template <int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots>
273class InputTransform<InnerTileRows, 1, TIn, TOut, Roots> :
274 public InputTransform<1, InnerTileRows, TIn, TOut, Roots>
275{
276 using Base = InputTransform<1, InnerTileRows, TIn, TOut, Roots>;
277
278 public:
279 InputTransform(
280 int kernel_rows, /**< Number of rows in the kernel. */
281 int kernel_cols, /**< Number of columns in the kernel. */
282 int n_batches, /**< Number of batches in input tensor. */
283 int n_rows, /**< Number of rows in input tensor. */
284 int n_cols, /**< Number of columns in input tensor. */
285 int n_channels, /**< Number of channels in input tensor. */
286 int padding_top, /**< Padding to apply to the top of the image. */
287 int padding_left, /**< Padding to apply to the left of the image. */
288 int padding_bottom, /**< Padding to apply to the bottom of the image. */
289 int padding_right /**< Padding to apply to the right of the image. */
290 );
291
292 /** Set pointers to the input tensor read by the transform. */
293 void set_input_tensor(const void *input) override;
294 void set_input_tensor(const void *input, int col_stride) override;
295 void set_input_tensor(const void *input, int row_stride, int col_stride) override;
296 void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override;
297};
298
299template <
300 int KernelRows, int KernelCols,
301 int InnerTileRows, int InnerTileCols,
302 typename TIn, typename TOut,
303 WinogradRoots Roots
304>
305class OutputTransform : public IOutputTransform
306{
307 public:
308 OutputTransform(
309 int n_batches, /**< Number of batches in output tensor. */
310 int n_rows, /**< Number of rows in output tensor. */
311 int n_cols, /**< Number of columns in output tensor. */
Pablo Tello5264b7d2019-10-21 14:25:41 +0100312 int n_channels, /**< Number of channels in output tensor. */
313 const arm_gemm::Activation &activation
Pablo Tello8f43d742019-03-27 09:28:32 +0000314 );
315
316 OutputTransform(OutputTransform&) = delete;
317 OutputTransform operator=(OutputTransform&) = delete;
318
319 /** Set pointers to the matrices read by the transform. */
320 void set_input_matrices(const void *matrices, int iter_matrix_stride, int matrix_row_stride) override;
321
322 /** Set pointer to the bias tensor (can be ignored or called with nullptr for no bias */
323 void set_bias(const void *bias=nullptr) override;
324
325 /** Set pointers to the output tensor written by the transform. */
326 void set_output_tensor(void *output) override;
327 void set_output_tensor(void *output, int col_stride) override;
328 void set_output_tensor(void *output, int row_stride, int col_stride) override;
329 void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override;
330
331 /** Get the working space required to perform the transformation. */
332 size_t get_working_space_size(unsigned int nthreads=1) const override;
333 void set_working_space(void *buffer) override;
334
335 /** Get the window of work a given operator can perform. */
336 unsigned int get_window() const override;
337 static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window
338
339 /** Perform work upon a window of the input. */
340 void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
341
342 protected:
343 static constexpr int inner_tile_rows = InnerTileRows;
344 static constexpr int inner_tile_cols = InnerTileCols;
345 static constexpr int output_tile_rows = InnerTileRows - KernelRows + 1;
346 static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1;
347
348 const int _n_batches, _n_rows, _n_cols, _n_channels;
Pablo Tello5264b7d2019-10-21 14:25:41 +0100349 const TOut _output_min, _output_max;
Pablo Tello8f43d742019-03-27 09:28:32 +0000350
351 private:
352 void transform_uncropped_tile(
353 unsigned int threadid,
354 int n_channels,
355 TOut *outptr,
356 const TIn *inptr,
357 const TOut *biases
358 );
359
360 void transform_cropped_tile(
361 unsigned int threadid,
362 int n_channels,
363 TOut *outptr,
364 const TIn *inptr,
365 const TOut *biases,
366 int pad_bottom,
367 int pad_right
368 );
369
370 /** Implementation of the tile transformation method. */
371 static void transform_tile(
372 int n_channels,
373 const TIn* matrix_base,
374 int matrix_stride,
375 const TOut* biases,
376 TOut* output,
377 int output_row_stride,
Pablo Tello5264b7d2019-10-21 14:25:41 +0100378 int output_col_stride,
379 TOut output_min,
380 TOut output_max
Pablo Tello8f43d742019-03-27 09:28:32 +0000381 );
382
383 /** Get the working space for a thread. */
384 void * get_working_space(unsigned int threadid) const;
385
386 const TIn* _matrix_base;
387 const TOut* _biases;
388 int _matrix_stride, _matrix_row_stride, _matrix_batch_stride;
389 TOut* _outptr;
390 const int _tiles_M, _tiles_N;
391 int _out_col_stride, _out_row_stride, _out_batch_stride;
392
393 const int _working_space_col_stride, _working_space_row_stride;
394 TOut *_working_space;
395};
396
397template <
398 int KernelRows,
399 int InnerTileRows,
400 typename TIn, typename TOut,
401 WinogradRoots Roots
402>
403class OutputTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
404 public OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>
405{
406 using Base = OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>;
407
408 public:
409 OutputTransform(
410 int n_batches, /**< Number of batches in output tensor. */
411 int n_rows, /**< Number of rows in output tensor. */
412 int n_cols, /**< Number of columns in output tensor. */
Pablo Tello5264b7d2019-10-21 14:25:41 +0100413 int n_channels, /**< Number of channels in output tensor. */
414 const arm_gemm::Activation &activation
Pablo Tello8f43d742019-03-27 09:28:32 +0000415 );
416
417 /** Set pointers to the output tensor written by the transform. */
418 void set_output_tensor(void *output) override;
419 void set_output_tensor(void *output, int col_stride) override;
420 void set_output_tensor(void *output, int row_stride, int col_stride) override;
421 void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override;
422};
423
424template <
425 int KernelRows, int KernelCols,
426 int InnerTileRows, int InnerTileCols,
427 typename TIn, typename TOut,
428 WinogradRoots Roots
429>
430class WeightTransform : public IWeightTransform
431{
432 public:
433 WeightTransform(
434 int n_output_channels, /**< Number of output channels in the kernel. */
435 int n_input_channels /**< Number of input channels in the kernel. */
436 );
437
438 WeightTransform(WeightTransform&) = delete;
439 WeightTransform operator=(WeightTransform&) = delete;
440
441 /** Set pointer to the weight tensor read by the transform. */
442 void set_weight_tensor(const void *weights) override;
443
444 /** Set pointer to the matrices written by the transform. */
445 void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) override;
446
447 /** Get the working space required to perform the transformation. */
448 size_t get_working_space_size(unsigned int nthreads=1) const override;
449 void set_working_space(void *buffer) override;
450
451 /** Get the window of work a given operator can perform. */
452 unsigned int get_window() const override;
453 static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window
454
455 /** Perform work upon a window of the input. */
456 void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
457
458 protected:
459 static const int kernel_rows = KernelRows;
460 static const int kernel_cols = KernelCols;
461 static const int inner_tile_rows = InnerTileRows;
462 static const int inner_tile_cols = InnerTileCols;
463
464 private:
465 /** Apply the transform to a tensor. */
466 static void execute(
467 int n_output_channels,
468 int n_input_channels,
469 const TIn* input,
470 TOut* output,
471 int matrix_stride,
472 int matrix_row_stride
473 );
474
475 const int _n_output_channels, _n_input_channels;
476 TOut *_matrices;
477 int _matrix_stride, _matrix_row_stride;
478 const TIn *_weights;
479};
480
481template <int KernelRows, int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots>
482class WeightTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
483 public WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>
484{
485 public:
486 using WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::WeightTransform;
487};
488
489template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, WinogradRoots Roots>
490class WinogradGEMM
491{
492 public:
493 // Information about the specific Winograd instance
494 static constexpr int output_tile_rows = OutputTileRows;
495 static constexpr int output_tile_cols = OutputTileCols;
496 static constexpr int kernel_rows = KernelRows;
497 static constexpr int kernel_cols = KernelCols;
498 static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;
499 static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;
500 static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
501
502 /** Transform weights from the spatial to the Winograd domain. */
503 template <typename TIn, typename TOut>
504 using WeightsTransform = WeightTransform<
505 KernelRows, KernelCols, inner_tile_rows, inner_tile_cols,
506 TIn, TOut, Roots
507 >;
508
509 /** Transform input feature maps from the spatial to the Winograd domain.
510 */
511 template <typename TIn, typename TOut>
512 using InputTransform = InputTransform<
513 inner_tile_rows, inner_tile_cols, TIn, TOut, Roots
514 >;
515
516 /** Transform output feature maps from the Winograd to the spatial domain.
517 */
518 template <typename TIn, typename TOut>
519 using OutputTransform = OutputTransform<
520 KernelRows, KernelCols, inner_tile_rows, inner_tile_cols,
521 TIn, TOut, Roots
522 >;
523
524 /** Perform a convolution.
525 */
526 template <typename TOut, typename TIn, typename TInGEMM=TIn, typename TOutGEMM=TOut>
527 class Convolution
528 {
529 public:
530 // Information about the typed Winograd instance
531 typedef TOut OutputType;
532 typedef TOutGEMM GemmOutputType;
533 typedef TInGEMM GemmInputType;
534 typedef TIn InputType;
535
536 /** Get the output shape of a convolution. */
Pablo Tello5264b7d2019-10-21 14:25:41 +0100537 static std::pair<unsigned int, unsigned int> get_output_shape(
538 const std::pair<unsigned int, unsigned int> input_shape,
539 bool padding_same);
Pablo Tello8f43d742019-03-27 09:28:32 +0000540
541 /** Get the memory required to store the kernel transformed into the
542 * Winograd domain.
543 */
Pablo Tello5264b7d2019-10-21 14:25:41 +0100544 static size_t get_kernel_storage_size(unsigned int n_input_channels,
545 unsigned int n_output_channels);
Pablo Tello8f43d742019-03-27 09:28:32 +0000546
547 /** Get the memory required to store the input tensor transformed into
548 * the Winograd domain.
549 */
550 static size_t get_input_storage_size(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100551 unsigned int n_batches, // Number of batches
552 unsigned int n_rows, // Number of input rows
553 unsigned int n_cols, // Number of input columns
554 unsigned int n_channels, // Number of input channels
555 bool padding_same);
Pablo Tello8f43d742019-03-27 09:28:32 +0000556
557 /** Get the memory required to store the output tensor in the Winograd
558 * domain.
559 */
560 static size_t get_output_storage_size(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100561 unsigned int n_batches, // Number of batches
562 unsigned int n_rows, // Number of output rows
563 unsigned int n_cols, // Number of output columns
564 unsigned int n_channels // Number of output channels
565 );
Pablo Tello8f43d742019-03-27 09:28:32 +0000566
567 /** Get the memory required to apply a Winograd operator to some input.
568 */
569 static size_t get_working_space_size(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100570 unsigned int n_batches,
571 unsigned int n_rows, // Number of input rows
572 unsigned int n_cols, // Number of input columns
573 unsigned int n_input_channels, // Number of input channels
574 unsigned int n_output_channels, // Number of output channels
575 bool padding_same);
Pablo Tello8f43d742019-03-27 09:28:32 +0000576
577 /* Get the memory required by a single "input" matrix.
578 */
579 static size_t get_input_matrix_size(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100580 unsigned int n_batches, // Number of batches
581 unsigned int n_rows, // Number of input rows
582 unsigned int n_cols, // Number of input columns
583 unsigned int n_channels, // Number of input channels
584 bool padding_same);
Pablo Tello8f43d742019-03-27 09:28:32 +0000585
586 static int get_input_matrix_stride(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100587 unsigned int n_batches, // Number of batches
588 unsigned int n_rows, // Number of input rows
589 unsigned int n_cols, // Number of input columns
590 unsigned int n_channels, // Number of input channels
591 bool padding_same);
Pablo Tello8f43d742019-03-27 09:28:32 +0000592
593 /* Get the memory required by a single "output" matrix.
594 */
595 static size_t get_output_matrix_size(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100596 unsigned int n_batches, // Number of batches
597 unsigned int n_rows, // Number of output rows
598 unsigned int n_cols, // Number of output columns
599 unsigned int n_channels // Number of output channels
600 );
Pablo Tello8f43d742019-03-27 09:28:32 +0000601
602 static int get_output_matrix_stride(
Pablo Tello5264b7d2019-10-21 14:25:41 +0100603 unsigned int n_batches, // Number of batches
604 unsigned int n_rows, // Number of output rows
605 unsigned int n_cols, // Number of output columns
606 unsigned int n_channels // Number of output channels
607 );
Pablo Tello8f43d742019-03-27 09:28:32 +0000608
609 /* Get the memory required by a single "kernel" matrix.
610 */
Pablo Tello5264b7d2019-10-21 14:25:41 +0100611 static size_t get_kernel_matrix_size(unsigned int n_input_channels,
612 unsigned int n_output_channels);
613 static int get_kernel_matrix_stride(unsigned int n_input_channels,
614 unsigned int n_output_channels);
Pablo Tello8f43d742019-03-27 09:28:32 +0000615
616 static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */
617 static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */
618 };
619};
620
621} // namespace winograd