/*
 * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#pragma once
26
27#include "arm_gemm.hpp"
28#include "arm_gemm_local.hpp"
29#include "depthwise_common.hpp"
30
31namespace arm_conv
32{
33namespace depthwise
34{
35struct DepthwiseConfig
36{
37 DepthwiseMethod method = DepthwiseMethod::DEFAULT;
38 std::string filter = "";
39
40 DepthwiseConfig(DepthwiseMethod method)
41 : method(method) {};
42 DepthwiseConfig() {};
43};
44
45struct DepthwiseArgs
46{
47 const CPUInfo *cpu_info;
48
49 unsigned int kernel_rows, kernel_cols;
50 unsigned int stride_rows, stride_cols;
51
52 unsigned int n_batches, input_rows, input_cols, input_channels;
53 unsigned int output_rows, output_cols;
54 unsigned int channel_multiplier;
55
56 PaddingValues padding;
57
58 arm_gemm::Activation activation;
59
60 const DepthwiseConfig *config;
61
ramelg018a164882022-04-07 02:42:52 +010062 bool fast_mode = false;
63
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000064 DepthwiseArgs(
65 const CPUInfo *cpu_info,
66 unsigned int kernel_rows, unsigned int kernel_cols,
67 unsigned int stride_rows, unsigned int stride_cols,
68 unsigned int n_batches, unsigned int input_rows, unsigned int input_cols,
69 unsigned int input_channels,
70 unsigned int output_rows, unsigned int output_cols,
71 unsigned int channel_multiplier,
72 PaddingValues padding, arm_gemm::Activation activation,
73 const DepthwiseConfig *config)
74 : cpu_info(cpu_info), kernel_rows(kernel_rows), kernel_cols(kernel_cols), stride_rows(stride_rows), stride_cols(stride_cols), n_batches(n_batches), input_rows(input_rows), input_cols(input_cols),
75 input_channels(input_channels), output_rows(output_rows), output_cols(output_cols), channel_multiplier(channel_multiplier), padding(padding), activation(activation), config(config)
76 {
77 }
78};
79
80template <typename TInput, typename TWeight, typename TOutput>
81class DepthwiseCommon : public IDepthwiseCommon
82{
Pablo Marquez Tello62a3b0c2021-12-06 12:15:00 +000083private:
84 std::string _name{};
85
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000086protected:
87 const DepthwiseArgs m_args; // Copy of arguments
ramelg018a164882022-04-07 02:42:52 +010088
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000089public:
Pablo Marquez Tello62a3b0c2021-12-06 12:15:00 +000090 std::string name() const
91 {
92 return _name;
93 }
ramelg018a164882022-04-07 02:42:52 +010094
Pablo Marquez Tello62a3b0c2021-12-06 12:15:00 +000095 void set_name(const std::string &n)
96 {
97 _name = n;
98 }
ramelg018a164882022-04-07 02:42:52 +010099
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000100 DepthwiseCommon(const DepthwiseArgs &args)
101 : m_args(args) {};
102 DepthwiseCommon(DepthwiseCommon &) = delete;
103 DepthwiseCommon &operator=(DepthwiseCommon &) = delete;
104
105 void execute(
106 const void *const input,
107 const void *const parameters,
108 void *const output,
109 void *const working_space,
110 const unsigned int thread_id,
ramelg018a164882022-04-07 02:42:52 +0100111 const unsigned int n_threads) const override final
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000112 {
113 const size_t ld_input_col = m_args.input_channels;
114 const size_t ld_input_row = ld_input_col * m_args.input_cols;
115 const size_t ld_input_batch = ld_input_row * m_args.input_rows;
116 const size_t ld_output_col = m_args.input_channels * m_args.channel_multiplier;
117 const size_t ld_output_row = ld_output_col * m_args.output_cols;
118 const size_t ld_output_batch = ld_output_row * m_args.output_rows;
119
120 execute(
121 input, ld_input_col, ld_input_row, ld_input_batch,
122 parameters, output, ld_output_col, ld_output_row, ld_output_batch,
123 working_space, thread_id, n_threads);
124 }
125
126 void execute(
127 const void *const input,
128 size_t ld_input_col,
129 size_t ld_input_row,
130 size_t ld_input_batch,
131 const void *const parameters,
132 void *const output,
133 size_t ld_output_col,
134 size_t ld_output_row,
135 size_t ld_output_batch,
136 void *const working_space,
137 const unsigned int thread_id,
ramelg018a164882022-04-07 02:42:52 +0100138 const unsigned int n_threads) const override final
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000139 {
140 execute(
141 m_args.n_batches, m_args.input_rows, m_args.input_cols,
142 m_args.input_channels, m_args.padding,
143 input, ld_input_col, ld_input_row, ld_input_batch,
144 parameters,
145 m_args.output_rows, m_args.output_cols,
146 output, ld_output_col, ld_output_row, ld_output_batch,
147 working_space, thread_id, n_threads);
148 }
149
ramelg018a164882022-04-07 02:42:52 +0100150 void execute(
151 unsigned int batches,
152 unsigned int input_height,
153 unsigned int input_width,
154 unsigned int channels,
155 const PaddingValues &padding,
156 const void *input,
157 size_t ld_input_col,
158 size_t ld_input_row,
159 size_t ld_input_batch,
160 const void *parameters,
161 unsigned int output_height,
162 unsigned int output_width,
163 void *output,
164 size_t ld_output_col,
165 size_t ld_output_row,
166 size_t ld_output_batch,
167 void *working_space,
168 unsigned int thread_id,
169 unsigned int n_threads) const override final
170 {
171 this->execute_internal(
172 batches, input_height, input_width, channels, padding, input,
173 ld_input_col, ld_input_row, ld_input_batch, parameters, output_height,
174 output_width, output, ld_output_col, ld_output_row, ld_output_batch,
175 working_space, thread_id, n_threads);
176 }
177
178protected:
179 virtual void execute_internal(
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000180 unsigned int batches,
181 unsigned int input_height,
182 unsigned int input_width,
183 unsigned int channels,
184 const PaddingValues &,
185 const void *input,
186 size_t ld_input_col,
187 size_t ld_input_row,
188 size_t ld_input_batch,
189 const void *parameters,
190 unsigned int output_height,
191 unsigned int output_width,
192 void *output,
193 size_t ld_output_col,
194 size_t ld_output_row,
195 size_t ld_output_batch,
196 void *working_space,
197 unsigned int thread_id,
ramelg018a164882022-04-07 02:42:52 +0100198 unsigned int n_threads) const = 0;
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000199};
200
201template <typename TInput, typename TWeight = TInput, typename TOutput = TInput>
202using UniqueDepthwiseCommon = std::unique_ptr<DepthwiseCommon<TInput, TWeight, TOutput>>;
203
204template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
205KernelDescription get_depthwise_method(const DepthwiseArgs &, const OutputStage & = {});
206
207template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
208UniqueDepthwiseCommon<TInput, TWeight, TOutput> depthwise(const DepthwiseArgs &, const OutputStage & = {});
209
210template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
211std::vector<KernelDescription> get_compatible_kernels(const DepthwiseArgs &, const OutputStage & = {});
212
213} // namespace depthwise
214} // namespace arm_conv