blob: d05da18b581e79f3efcc2a7956c52302e539d4b6 [file] [log] [blame]
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef SRC_RUNTIME_HEURISTICS_INDIRECT_CONV_ICLINDIRECTCONVKERNELCONFIG
25#define SRC_RUNTIME_HEURISTICS_INDIRECT_CONV_ICLINDIRECTCONVKERNELCONFIG
26
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"

#include "src/core/common/Macros.h"

#include <array>
#include <cstddef>
32
33namespace arm_compute
34{
35namespace cl_indirect_conv
36{
37/** Basic container for the OpenCL indirect convolution configuration functions */
38template <class T>
39class ClIndirectConvConfigArray
40{
41public:
42 /** Alias for F32 index */
43 static constexpr size_t DT_F32 = 0;
44 /** Alias for F16 index */
45 static constexpr size_t DT_F16 = 1;
46
47 /** Constructor
48 *
Gian Marco Iodice9d3bd412022-12-30 09:45:00 +000049 * @param[in] func_f32 Function to call for indirect convolution F32
50 * @param[in] func_f16 Function to call for indirect convolution F16
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +000051 *
52 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010053 ClIndirectConvConfigArray(T func_f32, T func_f16) : _configs{func_f32, func_f16}
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +000054 {
55 }
56
57 /** Method to return the indirect convolution configuration function based on data type
58 *
59 * @param[in] data_type Input data type
60 *
61 * @return the valid function otherwise it returns nullptr if the data type is not valid
62 */
63 T get_function(DataType data_type)
64 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010065 switch (data_type)
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +000066 {
67 case DataType::F32:
68 return _configs.at(DT_F32);
69 case DataType::F16:
70 return _configs.at(DT_F16);
71 default:
72 return nullptr;
73 }
74 }
75
76private:
77 std::array<T, 2> _configs;
78};
79
80/** Basic interface for the indirect convolution kernel configuration */
81class IClIndirectConvKernelConfig
82{
83public:
84 /** Constructor
85 *
86 * @param[in] arch GPU target
87 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010088 IClIndirectConvKernelConfig(GPUTarget arch) : _target(arch)
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +000089 {
90 }
91 ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(IClIndirectConvKernelConfig);
92 /** Virtual destructor */
93 virtual ~IClIndirectConvKernelConfig() = default;
94 /** This method returns the @ref DirectConvComputeKernelInfo for the given inputs
95 *
96 * @param[in] src Source tensor (activation tensor)
97 * @param[in] wei Weights tensor
98 * @param[in] conv_info Convolution info
99 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100100 virtual DirectConvComputeKernelInfo
101 configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) = 0;
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +0000102
103protected:
104 GPUTarget _target;
105};
Gian Marco Iodice9d3bd412022-12-30 09:45:00 +0000106} // namespace cl_indirect_conv
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +0000107} // namespace arm_compute
108#endif /* SRC_RUNTIME_HEURISTICS_INDIRECT_CONV_ICLINDIRECTCONVKERNELCONFIG */