blob: 9262ea05a49c91b947edc56b27811a119dc93412 [file] [log] [blame]
/*
 * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#pragma once
26
27#include "arm_gemm.hpp"
28#include "arm_gemm_local.hpp"
29#include "depthwise_common.hpp"
30
31namespace arm_conv
32{
33namespace depthwise
34{
35struct DepthwiseConfig
36{
37 DepthwiseMethod method = DepthwiseMethod::DEFAULT;
38 std::string filter = "";
39
40 DepthwiseConfig(DepthwiseMethod method)
41 : method(method) {};
42 DepthwiseConfig() {};
43};
44
45struct DepthwiseArgs
46{
47 const CPUInfo *cpu_info;
48
49 unsigned int kernel_rows, kernel_cols;
50 unsigned int stride_rows, stride_cols;
51
52 unsigned int n_batches, input_rows, input_cols, input_channels;
53 unsigned int output_rows, output_cols;
54 unsigned int channel_multiplier;
55
56 PaddingValues padding;
57
58 arm_gemm::Activation activation;
59
60 const DepthwiseConfig *config;
61
62 DepthwiseArgs(
63 const CPUInfo *cpu_info,
64 unsigned int kernel_rows, unsigned int kernel_cols,
65 unsigned int stride_rows, unsigned int stride_cols,
66 unsigned int n_batches, unsigned int input_rows, unsigned int input_cols,
67 unsigned int input_channels,
68 unsigned int output_rows, unsigned int output_cols,
69 unsigned int channel_multiplier,
70 PaddingValues padding, arm_gemm::Activation activation,
71 const DepthwiseConfig *config)
72 : cpu_info(cpu_info), kernel_rows(kernel_rows), kernel_cols(kernel_cols), stride_rows(stride_rows), stride_cols(stride_cols), n_batches(n_batches), input_rows(input_rows), input_cols(input_cols),
73 input_channels(input_channels), output_rows(output_rows), output_cols(output_cols), channel_multiplier(channel_multiplier), padding(padding), activation(activation), config(config)
74 {
75 }
76};
77
78template <typename TInput, typename TWeight, typename TOutput>
79class DepthwiseCommon : public IDepthwiseCommon
80{
Pablo Marquez Tello62a3b0c2021-12-06 12:15:00 +000081private:
82 std::string _name{};
83
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000084protected:
85 const DepthwiseArgs m_args; // Copy of arguments
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000086public:
Pablo Marquez Tello62a3b0c2021-12-06 12:15:00 +000087 std::string name() const
88 {
89 return _name;
90 }
91 void set_name(const std::string &n)
92 {
93 _name = n;
94 }
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000095 DepthwiseCommon(const DepthwiseArgs &args)
96 : m_args(args) {};
97 DepthwiseCommon(DepthwiseCommon &) = delete;
98 DepthwiseCommon &operator=(DepthwiseCommon &) = delete;
99
100 void execute(
101 const void *const input,
102 const void *const parameters,
103 void *const output,
104 void *const working_space,
105 const unsigned int thread_id,
106 const unsigned int n_threads) const override
107 {
108 const size_t ld_input_col = m_args.input_channels;
109 const size_t ld_input_row = ld_input_col * m_args.input_cols;
110 const size_t ld_input_batch = ld_input_row * m_args.input_rows;
111 const size_t ld_output_col = m_args.input_channels * m_args.channel_multiplier;
112 const size_t ld_output_row = ld_output_col * m_args.output_cols;
113 const size_t ld_output_batch = ld_output_row * m_args.output_rows;
114
115 execute(
116 input, ld_input_col, ld_input_row, ld_input_batch,
117 parameters, output, ld_output_col, ld_output_row, ld_output_batch,
118 working_space, thread_id, n_threads);
119 }
120
121 void execute(
122 const void *const input,
123 size_t ld_input_col,
124 size_t ld_input_row,
125 size_t ld_input_batch,
126 const void *const parameters,
127 void *const output,
128 size_t ld_output_col,
129 size_t ld_output_row,
130 size_t ld_output_batch,
131 void *const working_space,
132 const unsigned int thread_id,
133 const unsigned int n_threads) const override
134 {
135 execute(
136 m_args.n_batches, m_args.input_rows, m_args.input_cols,
137 m_args.input_channels, m_args.padding,
138 input, ld_input_col, ld_input_row, ld_input_batch,
139 parameters,
140 m_args.output_rows, m_args.output_cols,
141 output, ld_output_col, ld_output_row, ld_output_batch,
142 working_space, thread_id, n_threads);
143 }
144
145 virtual void execute(
146 unsigned int batches,
147 unsigned int input_height,
148 unsigned int input_width,
149 unsigned int channels,
150 const PaddingValues &,
151 const void *input,
152 size_t ld_input_col,
153 size_t ld_input_row,
154 size_t ld_input_batch,
155 const void *parameters,
156 unsigned int output_height,
157 unsigned int output_width,
158 void *output,
159 size_t ld_output_col,
160 size_t ld_output_row,
161 size_t ld_output_batch,
162 void *working_space,
163 unsigned int thread_id,
164 unsigned int n_threads) const override = 0;
165};
166
167template <typename TInput, typename TWeight = TInput, typename TOutput = TInput>
168using UniqueDepthwiseCommon = std::unique_ptr<DepthwiseCommon<TInput, TWeight, TOutput>>;
169
170template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
171KernelDescription get_depthwise_method(const DepthwiseArgs &, const OutputStage & = {});
172
173template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
174UniqueDepthwiseCommon<TInput, TWeight, TOutput> depthwise(const DepthwiseArgs &, const OutputStage & = {});
175
176template <typename TInput, typename TWeight = TInput, typename TOutput = TInput, class OutputStage = Nothing>
177std::vector<KernelDescription> get_compatible_kernels(const DepthwiseArgs &, const OutputStage & = {});
178
179} // namespace depthwise
180} // namespace arm_conv