blob: 89d594298ea0bb343e2e5ed6586e95daaa5a96f3 [file] [log] [blame]
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +00001/*
Michael Tyler8deee9b2023-06-30 11:26:05 +01002 * Copyright (c) 2021-2023 Arm Limited.
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#pragma once
26
27#include "arm_gemm_local.hpp"
28#include "pool_common.hpp"
29
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +000030namespace arm_conv
31{
32namespace pooling
33{
34struct PoolingConfig
35{
36 PoolingMethod method = PoolingMethod::DEFAULT;
37 std::string filter = "";
38
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010039 PoolingConfig(PoolingMethod method) : method(method){};
40 PoolingConfig(){};
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +000041};
42
43struct PoolingArgs
44{
45 const CPUInfo *cpu_info;
46
47 PoolingType pool_type;
48 PoolingWindow pool_window;
49 PoolingStride pool_stride;
50 bool exclude_padding;
51
52 unsigned int n_batches, input_rows, input_cols, n_channels;
53 unsigned int output_rows, output_cols;
54
55 PaddingValues padding;
56
57 const PoolingConfig *config;
58
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010059 PoolingArgs(const CPUInfo *cpu_info,
60 PoolingType pool_type,
61 const PoolingWindow &window,
62 const PoolingStride &stride,
63 bool exclude_padding,
64 unsigned int n_batches,
65 unsigned int input_rows,
66 unsigned int input_cols,
67 unsigned int n_channels,
68 unsigned int output_rows,
69 unsigned int output_cols,
70 const PaddingValues &padding,
71 const PoolingConfig *cfg)
72 : cpu_info(cpu_info),
73 pool_type(pool_type),
74 pool_window(window),
75 pool_stride(stride),
76 exclude_padding(exclude_padding),
77 n_batches(n_batches),
78 input_rows(input_rows),
79 input_cols(input_cols),
80 n_channels(n_channels),
81 output_rows(output_rows),
82 output_cols(output_cols),
83 padding(padding),
84 config(cfg)
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +000085 {
86 // If either of the pooling window dimensions are set to zero, meaning
87 // "pool everything", then replace with the corresponding input dimension.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010088 if (pool_window.rows == 0)
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +000089 {
90 pool_window.rows = input_rows;
91 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010092 if (pool_window.cols == 0)
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +000093 {
94 pool_window.cols = input_cols;
95 }
96 }
97};
98
// Empty tag type; used as the default OutputStage when no output stage is required.
struct Nothing {};
102
/** Per-layer requantisation parameters.
 *
 * Plain bundle of offsets, shifts and a multiplier; every field is supplied by
 * the sole constructor.
 * NOTE(review): presumably passed as the OutputStage of pooling() for quantized
 * pooling, with the exact arithmetic defined by the kernels — confirm.
 */
struct Requantize32
{
    int32_t input_offset  = 0;
    int32_t output_offset = 0;

    int32_t per_layer_left_shift  = 0;
    int32_t per_layer_right_shift = 0;
    int32_t per_layer_mul         = 0;

    Requantize32(int32_t input_offset,
                 int32_t output_offset,
                 int32_t per_layer_left_shift,
                 int32_t per_layer_right_shift,
                 int32_t per_layer_mul)
        : input_offset{input_offset}, output_offset{output_offset},
          per_layer_left_shift{per_layer_left_shift}, per_layer_right_shift{per_layer_right_shift},
          per_layer_mul{per_layer_mul}
    {}
};
125
ramelg01c827e992022-04-08 03:52:28 +0100126template <typename TInput, typename TOutput>
127class PoolingCommon : public IPoolingCommon
128{
129protected:
130 const PoolingArgs m_args;
131
132public:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100133 PoolingCommon(const PoolingArgs &args) : m_args(args)
ramelg01c827e992022-04-08 03:52:28 +0100134 {
135 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100136 PoolingCommon(PoolingCommon &) = delete;
ramelg01c827e992022-04-08 03:52:28 +0100137 PoolingCommon &operator=(PoolingCommon &) = delete;
138
Michael Tyler8deee9b2023-06-30 11:26:05 +0100139 size_t get_working_size(unsigned int) const override = 0;
ramelg01c827e992022-04-08 03:52:28 +0100140
141 // Execute pooling over the specified area of memory.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100142 void execute(const void *const input,
143 void *const output,
144 void *working_space,
145 unsigned int thread_id,
146 unsigned int num_threads) const override
ramelg01c827e992022-04-08 03:52:28 +0100147 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100148 this->execute(input, m_args.n_channels, m_args.n_channels * m_args.input_cols,
149 m_args.n_channels * m_args.input_cols * m_args.input_rows, output, m_args.n_channels,
150 m_args.n_channels * m_args.output_cols,
151 m_args.n_channels * m_args.output_cols * m_args.output_rows, working_space, thread_id,
152 num_threads);
ramelg01c827e992022-04-08 03:52:28 +0100153 }
154
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100155 void execute(const void *const input,
156 size_t ld_input_col,
157 size_t ld_input_row,
158 size_t ld_input_batch,
159 void *const output,
160 size_t ld_output_col,
161 size_t ld_output_row,
162 size_t ld_output_batch,
163 void *working_space,
164 unsigned int thread_id,
165 unsigned int num_threads) const override
ramelg01c827e992022-04-08 03:52:28 +0100166 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100167 this->execute(m_args.n_batches, m_args.input_rows, m_args.input_cols, m_args.n_channels, input, ld_input_col,
168 ld_input_row, ld_input_batch, m_args.padding, m_args.output_rows, m_args.output_cols, output,
169 ld_output_col, ld_output_row, ld_output_batch, working_space, thread_id, num_threads);
ramelg01c827e992022-04-08 03:52:28 +0100170 }
171
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100172 void execute(unsigned int batches,
173 unsigned int height,
174 unsigned int width,
175 unsigned int channels,
176 const void *const input,
177 size_t ld_input_col,
178 size_t ld_input_row,
179 size_t ld_input_batch,
180 const PaddingValues &padding,
181 unsigned int output_height,
182 unsigned int output_width,
183 void *const output,
184 size_t ld_output_col,
185 size_t ld_output_row,
186 size_t ld_output_batch,
187 void *working_space,
188 unsigned int thread_id,
189 unsigned int num_threads) const override
ramelg01c827e992022-04-08 03:52:28 +0100190 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100191 this->execute_internal(batches, height, width, channels, padding, input, ld_input_col, ld_input_row,
192 ld_input_batch, output_height, output_width, output, ld_output_col, ld_output_row,
193 ld_output_batch, working_space, thread_id, num_threads);
ramelg01c827e992022-04-08 03:52:28 +0100194 }
195
196protected:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100197 virtual void execute_internal(unsigned int batches,
198 unsigned int height,
199 unsigned int width,
200 unsigned int channels,
201 const PaddingValues &,
202 const void *const input,
203 size_t ld_input_col,
204 size_t ld_input_row,
205 size_t ld_input_batch,
206 unsigned int output_height,
207 unsigned int output_width,
208 void *const output,
209 size_t ld_output_col,
210 size_t ld_output_row,
211 size_t ld_output_batch,
212 void *working_space,
213 unsigned int thread_id,
214 unsigned int num_threads) const = 0;
ramelg01c827e992022-04-08 03:52:28 +0100215};
216
217template <typename TInput, typename TOutput>
218using UniquePoolingCommon = std::unique_ptr<PoolingCommon<TInput, TOutput>>;
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +0000219
220// Get a pooling engine
221template <typename TInput, typename TOutput = TInput, class OutputStage = Nothing>
ramelg01c827e992022-04-08 03:52:28 +0100222UniquePoolingCommon<TInput, TOutput> pooling(const PoolingArgs &, const OutputStage & = {});
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +0000223
224} // namespace pooling
225} // namespace arm_conv