blob: 2325bd08caaa596751916c80dc1427f8513818b8 [file] [log] [blame]
Michele Di Giorgiod556d7b2020-10-27 10:56:31 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#pragma once
26
#include "arm_gemm_local.hpp"
#include "pool_common.hpp"

#include <cstdint>
#include <memory>
#include <string>
31
32namespace arm_conv
33{
34namespace pooling
35{
36struct PoolingConfig
37{
38 PoolingMethod method = PoolingMethod::DEFAULT;
39 std::string filter = "";
40
41 PoolingConfig(PoolingMethod method)
42 : method(method) {};
43 PoolingConfig() {};
44};
45
46struct PoolingArgs
47{
48 const CPUInfo *cpu_info;
49
50 PoolingType pool_type;
51 PoolingWindow pool_window;
52 PoolingStride pool_stride;
53 bool exclude_padding;
54
55 unsigned int n_batches, input_rows, input_cols, n_channels;
56 unsigned int output_rows, output_cols;
57
58 PaddingValues padding;
59
60 const PoolingConfig *config;
61
62 PoolingArgs(
63 const CPUInfo *cpu_info,
64 PoolingType pool_type,
65 const PoolingWindow &window,
66 const PoolingStride &stride,
67 bool exclude_padding,
68 unsigned int n_batches,
69 unsigned int input_rows,
70 unsigned int input_cols,
71 unsigned int n_channels,
72 unsigned int output_rows,
73 unsigned int output_cols,
74 const PaddingValues &padding,
75 const PoolingConfig *cfg)
76 : cpu_info(cpu_info), pool_type(pool_type), pool_window(window), pool_stride(stride), exclude_padding(exclude_padding), n_batches(n_batches), input_rows(input_rows), input_cols(input_cols),
77 n_channels(n_channels), output_rows(output_rows), output_cols(output_cols), padding(padding), config(cfg)
78 {
79 // If either of the pooling window dimensions are set to zero, meaning
80 // "pool everything", then replace with the corresponding input dimension.
81 if(pool_window.rows == 0)
82 {
83 pool_window.rows = input_rows;
84 }
85 if(pool_window.cols == 0)
86 {
87 pool_window.cols = input_cols;
88 }
89 }
90};
91
/// Per-layer requantization parameters.
/// NOTE(review): presumably supplied as the `OutputStage` of `pooling()` for
/// quantized types; confirm against the implementations.
///
/// The only constructor assigns every field, so the in-class zero
/// initializers are never observed through it.
struct Requantize32
{
    int32_t input_offset{ 0 };
    int32_t output_offset{ 0 };

    int32_t per_layer_left_shift{ 0 };
    int32_t per_layer_right_shift{ 0 };
    int32_t per_layer_mul{ 0 };

    /// Copy each argument into the identically named field.
    Requantize32(int32_t input_offset,
                 int32_t output_offset,
                 int32_t per_layer_left_shift,
                 int32_t per_layer_right_shift,
                 int32_t per_layer_mul)
        : input_offset(input_offset),
          output_offset(output_offset),
          per_layer_left_shift(per_layer_left_shift),
          per_layer_right_shift(per_layer_right_shift),
          per_layer_mul(per_layer_mul)
    {
    }
};
108
109template <typename TInput, typename TOutput, class OutputStage = Nothing>
110using UniquePoolingCommon = std::unique_ptr<PoolingCommon<TInput, TOutput, OutputStage>>;
111
112// Get a pooling engine
113template <typename TInput, typename TOutput = TInput, class OutputStage = Nothing>
114UniquePoolingCommon<TInput, TOutput, OutputStage> pooling(const PoolingArgs &, const OutputStage & = {});
115
116} // namespace pooling
117} // namespace arm_conv