/*
 * Copyright (c) 2022-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "arm_gemm.hpp"

#include <cstddef>
#include <memory>
#include <string>

30namespace arm_conv
31{
32struct Shape2D
33{
34 unsigned int rows, cols;
35};
36
37struct ConvolutionArgs
38{
39 unsigned int n_batches;
40 Shape2D input_shape;
41 unsigned int n_input_channels;
42 unsigned int pad_top, pad_left;
43 Shape2D output_shape;
44 unsigned int n_output_channels;
45 Shape2D kernel_shape;
46 arm_gemm::Activation activation;
47
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010048 ConvolutionArgs(unsigned int n_batches,
49 const Shape2D &input_shape,
50 unsigned int n_input_channels,
51 unsigned int pad_top,
52 unsigned int pad_left,
53 const Shape2D &output_shape,
54 unsigned int n_output_channels,
55 const Shape2D kernel_shape,
56 const arm_gemm::Activation &activation = {})
57 : n_batches(n_batches),
58 input_shape(input_shape),
59 n_input_channels(n_input_channels),
60 pad_top(pad_top),
61 pad_left(pad_left),
62 output_shape(output_shape),
63 n_output_channels(n_output_channels),
64 kernel_shape(kernel_shape),
65 activation(activation)
ramelg01a1f78512022-06-29 16:28:10 +010066 {
67 }
68};
69
70namespace winograd
71{
72/* Constrain the selected Winograd implementation.
73 */
74struct WinogradConfig
75{
76 unsigned int output_rows = 0, output_cols = 0;
77 std::string input_transform_filter = "";
78 std::string output_transform_filter = "";
79 std::string weight_transform_filter = "";
80};
81
82/* Struct describing (suggested) memory layout within the Winograd domain.
83 */
84struct WinogradDomainSpec
85{
86 size_t weight_matrix_size_bytes, input_matrix_size_bytes, output_matrix_size_bytes;
87
88 size_t weight_ld_matrix, weight_ld_row;
89 size_t input_ld_batch, input_ld_matrix, input_ld_row;
90 size_t output_ld_batch, output_ld_matrix, output_ld_row;
91};
92
93class ITransformCommon
94{
95public:
96 virtual ~ITransformCommon() = default;
97
98 // Get the name of the transform
99 virtual const std::string &get_name(void) const = 0;
100};
101
102namespace weight_transform
103{
104class ITransform : public ITransformCommon
105{
106public:
107 ~ITransform() = default;
108
109 virtual unsigned int get_kernel_rows(void) const = 0;
110 virtual unsigned int get_kernel_cols(void) const = 0;
111
112 virtual unsigned int get_transformed_tile_rows(void) const = 0;
113 virtual unsigned int get_transformed_tile_cols(void) const = 0;
114
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100115 void execute(const ConvolutionArgs &args,
116 const void *inptr,
117 size_t ld_in_row,
118 size_t ld_in_col,
119 size_t ld_input_channel,
120 void *outptr,
121 const WinogradDomainSpec &wds,
122 unsigned int thread_id,
123 unsigned int n_threads) const
ramelg01a1f78512022-06-29 16:28:10 +0100124 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100125 this->execute(args, inptr, ld_in_row, ld_in_col, ld_input_channel, outptr, wds.weight_ld_matrix,
126 wds.weight_ld_row, thread_id, n_threads);
ramelg01a1f78512022-06-29 16:28:10 +0100127 }
128
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100129 virtual void execute(const ConvolutionArgs &args,
130 const void *inptr,
131 size_t ld_in_row,
132 size_t ld_in_col,
133 size_t ld_input_channel,
134 void *outptr,
135 size_t ld_out_matrix,
136 size_t ld_out_row,
137 unsigned int thread_id,
138 unsigned int n_threads) const = 0;
ramelg01a1f78512022-06-29 16:28:10 +0100139};
140
141} // namespace weight_transform
142
143namespace input_transform
144{
145class ITransform : public ITransformCommon
146{
147public:
148 ~ITransform() = default;
149
150 virtual unsigned int get_input_rows(void) const = 0;
151 virtual unsigned int get_input_cols(void) const = 0;
152
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100153 virtual size_t get_working_space_size(const ConvolutionArgs &args, unsigned int n_threads) const = 0;
ramelg01a1f78512022-06-29 16:28:10 +0100154
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100155 void execute(const ConvolutionArgs &args,
156 const void *inptr,
157 size_t ld_in_batch,
158 size_t ld_in_row,
159 size_t ld_in_col,
160 void *outptr,
161 const WinogradDomainSpec &wds,
162 void *working_space,
163 unsigned int thread_id,
164 unsigned int n_threads) const
ramelg01a1f78512022-06-29 16:28:10 +0100165 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100166 this->execute(args, inptr, ld_in_batch, ld_in_row, ld_in_col, outptr, wds.input_ld_batch, wds.input_ld_matrix,
167 wds.input_ld_row, working_space, thread_id, n_threads);
ramelg01a1f78512022-06-29 16:28:10 +0100168 }
169
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100170 virtual void execute(const ConvolutionArgs &args,
171 const void *inptr,
172 size_t ld_in_batch,
173 size_t ld_in_row,
174 size_t ld_in_col,
175 void *outptr,
176 size_t ld_out_batch,
177 size_t ld_out_matrix,
178 size_t ld_out_row,
179 void *working_space,
180 unsigned int thread_id,
181 unsigned int n_threads) const = 0;
ramelg01a1f78512022-06-29 16:28:10 +0100182};
183
184} // namespace input_transform
185
186namespace output_transform
187{
188class ITransform : public ITransformCommon
189{
190public:
191 ~ITransform() = default;
192
193 virtual unsigned int get_input_rows(void) const = 0;
194 virtual unsigned int get_input_cols(void) const = 0;
195
196 virtual unsigned int get_output_rows(void) const = 0;
197 virtual unsigned int get_output_cols(void) const = 0;
198
199 virtual unsigned int get_kernel_rows(void) const = 0;
200 virtual unsigned int get_kernel_cols(void) const = 0;
201
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100202 virtual size_t get_working_space_size(const ConvolutionArgs &args, unsigned int n_threads) const = 0;
ramelg01a1f78512022-06-29 16:28:10 +0100203
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100204 void execute(const ConvolutionArgs &args,
205 const void *inptr,
206 const WinogradDomainSpec &wds,
207 const void *bias,
208 void *outptr,
209 size_t ld_out_batch,
210 size_t ld_out_row,
211 size_t ld_out_col,
212 void *working_space,
213 unsigned int thread_id,
214 unsigned int n_threads) const
ramelg01a1f78512022-06-29 16:28:10 +0100215 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100216 this->execute(args, inptr, wds.output_ld_batch, wds.output_ld_matrix, wds.output_ld_row, bias, outptr,
217 ld_out_batch, ld_out_row, ld_out_col, working_space, thread_id, n_threads);
ramelg01a1f78512022-06-29 16:28:10 +0100218 }
219
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100220 virtual void execute(const ConvolutionArgs &args,
221 const void *inptr,
222 size_t ld_in_batch,
223 size_t ld_in_matrix,
224 size_t ld_in_row,
225 const void *bias,
226 void *outptr,
227 size_t ld_out_batch,
228 size_t ld_out_row,
229 size_t ld_out_col,
230 void *working_space,
231 unsigned int thread_id,
232 unsigned int n_threads) const = 0;
ramelg01a1f78512022-06-29 16:28:10 +0100233};
234
235} // namespace output_transform
236
237struct WinogradImpl
238{
239 const output_transform::ITransform *output_transform = nullptr;
240 const weight_transform::ITransform *weight_transform = nullptr;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100241 const input_transform::ITransform *input_transform = nullptr;
ramelg01a1f78512022-06-29 16:28:10 +0100242 std::unique_ptr<arm_gemm::GemmArgs> gemm_args;
243 WinogradDomainSpec winograd_spec;
244};
245
246/* Get pointers to Winograd transforms for the given convolution problem.
247 *
248 * Assigns to the pointers in the `dest` struct and returns true or false to
249 * indicate whether the given problem can be executed or not.
250 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100251template <typename TIn,
252 typename TWeight = TIn,
253 typename TOut = TIn,
254 typename TWinogradIn = TIn,
255 typename TWinogradOut = TOut>
256bool get_implementation(WinogradImpl &dest, // Destination for the selected implementation
257 const CPUInfo *,
258 const ConvolutionArgs &,
259 int max_threads,
260 bool fast_mode,
261 const WinogradConfig *,
262 const arm_gemm::GemmConfig *);
ramelg01a1f78512022-06-29 16:28:10 +0100263
264} // namespace winograd
265} // namespace arm_conv