blob: c33e1897970f2c73c493c4563573707929321b23 [file] [log] [blame]
SiCong Lib63b1192022-01-28 18:24:39 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
25#error "This experimental feature must be enabled with -DENABLE_EXPERIMENTAL_DYNAMIC_FUSION"
26#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */
27#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
28#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
29
30#include "arm_compute/core/experimental/ClWorkload.h"
31#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ITensorDescPack.h"
32
33#include "support/Cast.h"
34#include "support/DeepCopy.h"
35
36#include <map>
37#include <tuple>
38#include <type_traits>
39
40namespace arm_compute
41{
42namespace experimental
43{
44namespace dynamic_fusion
45{
/** Coarse classification of an operator, as reported by OperatorContent::complexity(). */
enum class OperatorComplexity
{
    Complex = 0, /**< Reported by Conv2dContent */
    Simple  = 1  /**< Reported by AddContent */
};
51
52struct ClKernelGraph;
53struct OpTensorContent
54{
55public:
56 using Id = DependencyGraph::Id;
57 OpTensorContent() = default;
58 OpTensorContent(Id id)
59 : id{ id }, desc{}
60 {
61 }
62 OpTensorContent(Id id, ITensorInfo *desc)
63 : id{ id }, desc{ desc }
64 {
65 }
66 ~OpTensorContent() = default;
67 OpTensorContent(const OpTensorContent &) = default;
68 OpTensorContent &operator=(const OpTensorContent &) = default;
69 OpTensorContent(OpTensorContent &&) = default;
70 OpTensorContent &operator=(OpTensorContent &&) = default;
71 bool operator==(const OpTensorContent &other) const
72 {
73 return desc == other.desc;
74 }
75
76 const ITensorInfo *get_tensor_info() const
77 {
78 return desc;
79 }
80 ITensorInfo *get_tensor_info()
81 {
82 return desc;
83 }
84
85 Id id{};
86 ITensorInfo *desc{};
87};
88
89struct OperatorContent
90{
91public:
92 using Id = DependencyGraph::Id;
93 OperatorContent() = default;
94 OperatorContent(const OperatorGraph::Implementation *graph, Id id, const ITensorDescPack<OpTensorContent> &tensors)
95 : _graph{ graph }, _id{ id }, _tensors{ tensors }
96 {
97 }
98 OperatorContent(const OperatorContent &op) = default;
99 OperatorContent &operator=(const OperatorContent &op) = default;
100 OperatorContent(OperatorContent &&op) = default;
101 OperatorContent &operator=(OperatorContent &&op) = default;
102 virtual ~OperatorContent() = default;
103 virtual OperatorComplexity complexity() const = 0;
104 virtual bool operator==(const OperatorContent &other) const = 0;
105 virtual Status translate(ClKernelGraph &kernel_graph) const = 0;
106
107protected:
108 const OperatorGraph::Implementation *_graph {};
109 Id _id{};
110 ITensorDescPack<OpTensorContent> _tensors{};
111};
112
113struct Conv2dContent : public OperatorContent
114{
115public:
116 Conv2dContent() = default;
117 Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
118 : OperatorContent(graph, id, tensors), desc(desc), forced_method(), forced_method_enabled(false)
119 {
120 }
121 // Temporary. Do not need to pass ConvolutionMethod
122 Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors, ConvolutionMethod method)
123 : OperatorContent(graph, id, tensors), desc(desc), forced_method(method), forced_method_enabled(true)
124 {
125 }
126 ~Conv2dContent() = default;
127 Conv2dContent(const Conv2dContent &) = default;
128 Conv2dContent &operator=(const Conv2dContent &) = default;
129 Conv2dContent(Conv2dContent &&) = default;
130 Conv2dContent &operator=(Conv2dContent &&) = default;
131 bool operator==(const OperatorContent &other) const override;
132 OperatorComplexity complexity() const override
133 {
134 return OperatorComplexity::Complex;
135 }
136 void set_method(ConvolutionMethod method)
137 {
138 forced_method_enabled = true;
139 forced_method = method;
140 }
141
142 Status translate(ClKernelGraph &kernel_graph) const override;
143 /** Replicate heuristics of @ref ClConv2d::get_convolution_method(), except that non-supported data types and data layouts are removed from the heuristics
144 *
145 * @param src
146 * @param weights
147 * @param dst
148 * @param conv2d_desc
149 * @param gpu_target
150 * @return ConvolutionMethod
151 */
152 static ConvolutionMethod select_conv_method(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const Conv2dDescriptor &conv2d_desc, const GPUTarget gpu_target);
153
154 Conv2dDescriptor desc{};
155 ConvolutionMethod forced_method{ ConvolutionMethod::GEMM_CONV2D };
156 bool forced_method_enabled{ false };
157
158private:
159 Status translate_direct_conv2d(ClKernelGraph &kernel_graph) const;
160};
161
162class AddContent : public OperatorContent
163{
164public:
165 AddContent() = default;
166 AddContent(const OperatorGraph::Implementation *graph, Id id, const AddDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
167 : OperatorContent(graph, id, tensors), desc(desc)
168 {
169 }
170 ~AddContent() = default;
171 AddContent(const AddContent &) = default;
172 AddContent &operator=(const AddContent &) = default;
173 AddContent(AddContent &&) = default;
174 AddContent &operator=(AddContent &&) = default;
175 bool operator==(const OperatorContent &other) const override;
176 OperatorComplexity complexity() const override
177 {
178 return OperatorComplexity::Simple;
179 }
180 Status translate(ClKernelGraph &kernel_graph) const override;
181
182private:
183 AddDescriptor desc{};
184};
185
186struct OperatorGraph::Implementation
187{
188public:
189 template <typename ContentT, typename... Args>
190 void add_node(Operator::Id id, Args &&... args)
191 {
192 operators[id] = utils::memory::make_deep_unique<OperatorContent, ContentT>(this, id, std::forward<Args>(args)...);
193 }
194
195 template <typename... Args>
196 void add_tensor(OpTensor::Id id, Args &&... args)
197 {
198 tensors[id] = utils::memory::make_deep_unique<OpTensorContent, OpTensorContent>(id, std::forward<Args>(args)...);
199 }
200
201 using Dependency = DependencyGraph;
202 using OperatorMap = std::map<Operator::Id, utils::memory::deep_unique_ptr<OperatorContent>>;
203 using OpTensorMap = std::map<OpTensor::Id, utils::memory::deep_unique_ptr<OpTensorContent>>;
204
205 Implementation() = default;
206 ~Implementation() = default;
207
208 friend bool operator==(const OperatorGraph::Implementation &graph0, const OperatorGraph::Implementation &graph1)
209 {
210 return graph0.graph == graph1.graph && graph0.operators == graph1.operators && graph0.tensors == graph1.tensors;
211 }
212
213 Dependency graph{};
214 OperatorMap operators{};
215 OpTensorMap tensors{};
216 Status status{};
217};
218
219std::vector<const OperatorContent *> traverse(const OperatorGraph::Implementation &graph);
220
221std::vector<OperatorContent *> traverse(OperatorGraph::Implementation &graph);
222
223Status translate(ClKernelGraph &kernel_graph, const OperatorGraph::Implementation &op_graph);
224
225} // namespace dynamic_fusion
226} // namespace experimental
227} // namespace arm_compute
228
229#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL