blob: 1b47853eaf8c9032bb94a6cd11c18ad5dfe048fa [file] [log] [blame]
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +00001/*
ramelg01c827e992022-04-08 03:52:28 +01002 * Copyright (c) 2021-2022 Arm Limited.
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#pragma once
26
27#include "arm_gemm_local.hpp"
28#include "pool_common.hpp"
29
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +000030namespace arm_conv
31{
32namespace pooling
33{
34struct PoolingConfig
35{
36 PoolingMethod method = PoolingMethod::DEFAULT;
37 std::string filter = "";
38
39 PoolingConfig(PoolingMethod method)
40 : method(method) {};
41 PoolingConfig() {};
42};
43
44struct PoolingArgs
45{
46 const CPUInfo *cpu_info;
47
48 PoolingType pool_type;
49 PoolingWindow pool_window;
50 PoolingStride pool_stride;
51 bool exclude_padding;
52
53 unsigned int n_batches, input_rows, input_cols, n_channels;
54 unsigned int output_rows, output_cols;
55
56 PaddingValues padding;
57
58 const PoolingConfig *config;
59
60 PoolingArgs(
61 const CPUInfo *cpu_info,
62 PoolingType pool_type,
63 const PoolingWindow &window,
64 const PoolingStride &stride,
65 bool exclude_padding,
66 unsigned int n_batches,
67 unsigned int input_rows,
68 unsigned int input_cols,
69 unsigned int n_channels,
70 unsigned int output_rows,
71 unsigned int output_cols,
72 const PaddingValues &padding,
73 const PoolingConfig *cfg)
74 : cpu_info(cpu_info), pool_type(pool_type), pool_window(window), pool_stride(stride), exclude_padding(exclude_padding), n_batches(n_batches), input_rows(input_rows), input_cols(input_cols),
75 n_channels(n_channels), output_rows(output_rows), output_cols(output_cols), padding(padding), config(cfg)
76 {
77 // If either of the pooling window dimensions are set to zero, meaning
78 // "pool everything", then replace with the corresponding input dimension.
79 if(pool_window.rows == 0)
80 {
81 pool_window.rows = input_rows;
82 }
83 if(pool_window.cols == 0)
84 {
85 pool_window.cols = input_cols;
86 }
87 }
88};
89
/** Empty tag type.
 *
 * Serves as the default `OutputStage` template argument of `pooling()`, meaning
 * that no output stage (for example, no requantization) is requested.
 */
struct Nothing {};
93
/** Per-layer 32-bit requantization parameters for quantized pooling.
 *
 * All five fields are set by the only constructor. The `= 0` member
 * initialisers are kept so that every field has a defined default.
 *
 * NOTE(review): the offsets are presumably the input/output zero points, and
 * the (left shift, multiplier, right shift) triple is presumably a fixed-point
 * rescale applied per layer. The arithmetic lives in the kernels, not here;
 * confirm there.
 */
struct Requantize32
{
    int32_t input_offset  = 0;
    int32_t output_offset = 0;

    int32_t per_layer_left_shift  = 0;
    int32_t per_layer_right_shift = 0;
    int32_t per_layer_mul         = 0;

    Requantize32(int32_t input_offset,
                 int32_t output_offset,
                 int32_t per_layer_left_shift,
                 int32_t per_layer_right_shift,
                 int32_t per_layer_mul)
        : input_offset{ input_offset },
          output_offset{ output_offset },
          per_layer_left_shift{ per_layer_left_shift },
          per_layer_right_shift{ per_layer_right_shift },
          per_layer_mul{ per_layer_mul }
    {
    }
};
110
ramelg01c827e992022-04-08 03:52:28 +0100111template <typename TInput, typename TOutput>
112class PoolingCommon : public IPoolingCommon
113{
114protected:
115 const PoolingArgs m_args;
116
117public:
118 PoolingCommon(const PoolingArgs &args)
119 : m_args(args)
120 {
121 }
122 PoolingCommon(PoolingCommon &) = delete;
123 PoolingCommon &operator=(PoolingCommon &) = delete;
124
125 size_t get_working_size(unsigned int, unsigned int) const override = 0;
126 size_t get_working_size(unsigned int n_threads) const override
127 {
128 return this->get_working_size(n_threads, m_args.n_channels);
129 }
130
131 // Execute pooling over the specified area of memory.
132 void execute(
133 const void *const input,
134 void *const output,
135 void *working_space,
136 unsigned int thread_id,
137 unsigned int num_threads) const override
138 {
139 this->execute(
140 input,
141 m_args.n_channels,
142 m_args.n_channels * m_args.input_cols,
143 m_args.n_channels * m_args.input_cols * m_args.input_rows,
144 output,
145 m_args.n_channels,
146 m_args.n_channels * m_args.output_cols,
147 m_args.n_channels * m_args.output_cols * m_args.output_rows,
148 working_space,
149 thread_id, num_threads);
150 }
151
152 void execute(
153 const void *const input,
154 size_t ld_input_col,
155 size_t ld_input_row,
156 size_t ld_input_batch,
157 void *const output,
158 size_t ld_output_col,
159 size_t ld_output_row,
160 size_t ld_output_batch,
161 void *working_space,
162 unsigned int thread_id,
163 unsigned int num_threads) const override
164 {
165 this->execute(
166 m_args.n_batches, m_args.input_rows, m_args.input_cols, m_args.n_channels,
167 input, ld_input_col, ld_input_row, ld_input_batch,
168 m_args.padding, m_args.output_rows, m_args.output_cols,
169 output, ld_output_col, ld_output_row, ld_output_batch,
170 working_space, thread_id, num_threads);
171 }
172
173 void execute(
174 unsigned int batches,
175 unsigned int height,
176 unsigned int width,
177 unsigned int channels,
178 const void *const input,
179 size_t ld_input_col,
180 size_t ld_input_row,
181 size_t ld_input_batch,
182 const PaddingValues &padding,
183 unsigned int output_height,
184 unsigned int output_width,
185 void *const output,
186 size_t ld_output_col,
187 size_t ld_output_row,
188 size_t ld_output_batch,
189 void *working_space,
190 unsigned int thread_id,
191 unsigned int num_threads) const override
192 {
193 this->execute_internal(
194 batches, height, width, channels, padding,
195 input, ld_input_col, ld_input_row, ld_input_batch,
196 output_height, output_width,
197 output, ld_output_col, ld_output_row, ld_output_batch,
198 working_space, thread_id, num_threads);
199 }
200
201protected:
202 virtual void execute_internal(
203 unsigned int batches,
204 unsigned int height,
205 unsigned int width,
206 unsigned int channels,
207 const PaddingValues &,
208 const void *const input,
209 size_t ld_input_col,
210 size_t ld_input_row,
211 size_t ld_input_batch,
212 unsigned int output_height,
213 unsigned int output_width,
214 void *const output,
215 size_t ld_output_col,
216 size_t ld_output_row,
217 size_t ld_output_batch,
218 void *working_space,
219 unsigned int thread_id,
220 unsigned int num_threads) const = 0;
221};
222
223template <typename TInput, typename TOutput>
224using UniquePoolingCommon = std::unique_ptr<PoolingCommon<TInput, TOutput>>;
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +0000225
226// Get a pooling engine
227template <typename TInput, typename TOutput = TInput, class OutputStage = Nothing>
ramelg01c827e992022-04-08 03:52:28 +0100228UniquePoolingCommon<TInput, TOutput> pooling(const PoolingArgs &, const OutputStage & = {});
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +0000229
230} // namespace pooling
231} // namespace arm_conv