blob: d2f4cde662e96de20ef584b9947b36fa8f4d79b7 [file] [log] [blame]
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef SRC_RUNTIME_HEURISTICS_INDIRECT_CONV_ICLINDIRECTCONVKERNELCONFIG
25#define SRC_RUNTIME_HEURISTICS_INDIRECT_CONV_ICLINDIRECTCONVKERNELCONFIG
26
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"
#include "src/core/common/Macros.h"

#include <array>
#include <cstddef>
31
32namespace arm_compute
33{
34namespace cl_indirect_conv
35{
36/** Basic container for the OpenCL indirect convolution configuration functions */
37template <class T>
38class ClIndirectConvConfigArray
39{
40public:
41 /** Alias for F32 index */
42 static constexpr size_t DT_F32 = 0;
43 /** Alias for F16 index */
44 static constexpr size_t DT_F16 = 1;
45
46 /** Constructor
47 *
Gian Marco Iodice9d3bd412022-12-30 09:45:00 +000048 * @param[in] func_f32 Function to call for indirect convolution F32
49 * @param[in] func_f16 Function to call for indirect convolution F16
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +000050 *
51 */
52 ClIndirectConvConfigArray(T func_f32, T func_f16)
53 : _configs{ func_f32, func_f16}
54 {
55 }
56
57 /** Method to return the indirect convolution configuration function based on data type
58 *
59 * @param[in] data_type Input data type
60 *
61 * @return the valid function otherwise it returns nullptr if the data type is not valid
62 */
63 T get_function(DataType data_type)
64 {
65 switch(data_type)
66 {
67 case DataType::F32:
68 return _configs.at(DT_F32);
69 case DataType::F16:
70 return _configs.at(DT_F16);
71 default:
72 return nullptr;
73 }
74 }
75
76private:
77 std::array<T, 2> _configs;
78};
79
80/** Basic interface for the indirect convolution kernel configuration */
81class IClIndirectConvKernelConfig
82{
83public:
84 /** Constructor
85 *
86 * @param[in] arch GPU target
87 */
88 IClIndirectConvKernelConfig(GPUTarget arch)
89 : _target(arch)
90 {
91 }
92 ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(IClIndirectConvKernelConfig);
93 /** Virtual destructor */
94 virtual ~IClIndirectConvKernelConfig() = default;
95 /** This method returns the @ref DirectConvComputeKernelInfo for the given inputs
96 *
97 * @param[in] src Source tensor (activation tensor)
98 * @param[in] wei Weights tensor
99 * @param[in] conv_info Convolution info
100 */
101 virtual DirectConvComputeKernelInfo configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) = 0;
102
103protected:
104 GPUTarget _target;
105};
Gian Marco Iodice9d3bd412022-12-30 09:45:00 +0000106} // namespace cl_indirect_conv
Gian Marco Iodicea5cb79f2022-12-28 13:53:51 +0000107} // namespace arm_compute
108#endif /* SRC_RUNTIME_HEURISTICS_INDIRECT_CONV_ICLINDIRECTCONVKERNELCONFIG */