blob: 54e01ea8501c428ea49f80e415a04f4fb67d74bb [file] [log] [blame]
SiCong Lib63b1192022-01-28 18:24:39 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
SiCong Li4e9f5682022-05-10 10:15:59 +010024#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
SiCong Lib63b1192022-01-28 18:24:39 +000025#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_CLKERNELGRAPH_H
26#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_CLKERNELGRAPH_H
27
28#include "arm_compute/core/TensorInfo.h"
29#include "arm_compute/core/Validate.h"
30#include "arm_compute/core/experimental/ClWorkload.h"
31#include "arm_compute/core/experimental/DependencyGraph.h"
32#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelDescriptors.h"
33#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ITensorDescPack.h"
34#include "support/DeepCopy.h"
35
36namespace arm_compute
37{
38namespace experimental
39{
40namespace dynamic_fusion
41{
42struct ClKernelGraph;
43class ClKernelBlueprint;
44
/** Coarse complexity classification of a kernel, as reported by ClKernel::complexity() */
enum class Complexity
{
    Simple  = 0, ///< Reported by e.g. ClAddKernel
    Complex = 1  ///< Reported by e.g. ClDirectConv2dKernel
};
50
51/** Configurations for ClKernel
52 *
53 */
54struct ClKernelConfig
55{
56 UnitWorkloadStage stage{};
57 TileDescriptor tile_desc{};
58 StoreType store_type{};
59 friend bool operator==(const ClKernelConfig &config0, const ClKernelConfig &config1)
60 {
61 return config0.stage == config1.stage && config0.tile_desc == config1.tile_desc && config0.store_type == config1.store_type;
62 }
63};
64
65struct ClKernelTensor
66{
67public:
68 using Id = DependencyGraph::Id;
69 ClKernelTensor() = default;
70 ClKernelTensor(Id id, ITensorInfo *desc, MemoryType memory_type, const AuxMemoryInfo &memory_info)
71 : id{ id }, desc{ desc }, memory_type{ memory_type }, memory_info{ memory_info }
72 {
73 }
74 bool operator==(const ClKernelTensor &other) const
75 {
76 return desc == other.desc;
77 }
78
79 Id id{};
80 ITensorInfo *desc{};
81 MemoryType memory_type{};
82 AuxMemoryInfo memory_info{};
83};
84
85struct ClKernel
86{
87public:
88 using Id = DependencyGraph::Id;
89 ClKernel() = default;
90 virtual ~ClKernel() = default;
91 ClKernel(const ClKernel &kernel) = default;
92 ClKernel &operator=(const ClKernel &kernel) = default;
93 ClKernel(ClKernel &&kernel) = default;
94 ClKernel &operator=(ClKernel &&kernel) = default;
95 ClKernel(const ClKernelGraph *graph, Id id, const ClKernelConfig &config, const ITensorDescPack<ClKernelTensor> &tensors)
96 : _graph{ graph }, _id{ id }, _config{ config }, _tensors{ tensors }
97 {
98 }
99 virtual bool operator==(const ClKernel &other) const = 0;
100 virtual Complexity complexity() const = 0;
101 virtual Status generate(ClKernelBlueprint &bp) const = 0;
102 Id id() const
103 {
104 return _id;
105 }
106 ITensorDescPack<ClKernelTensor> tensors() const
107 {
108 return _tensors;
109 }
110 ClKernelConfig config() const
111 {
112 return _config;
113 }
114
115protected:
116 const ClKernelGraph *_graph {};
117 Id _id{};
118 ClKernelConfig _config{};
119 ITensorDescPack<ClKernelTensor> _tensors{};
120};
121
122struct ClDirectConv2dKernel : public ClKernel
123{
124public:
125 Complexity complexity() const override
126 {
127 return Complexity::Complex;
128 }
129 ClDirectConv2dKernel() = default;
130 ~ClDirectConv2dKernel() override = default;
131 ClDirectConv2dKernel(const ClKernelGraph *graph, Id id, const ClKernelConfig config, const ClDirectConv2dKernelDescriptor &desc, const ITensorDescPack<ClKernelTensor> tensors)
132 : ClKernel{ graph, id, config, tensors }, desc{ desc }
133 {
134 }
135 static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ClDirectConv2dKernelDescriptor &conv2d_desc);
136 bool operator==(const ClKernel &other) const override;
137 Status generate(ClKernelBlueprint &bp) const override;
138
139 ClDirectConv2dKernelDescriptor desc{};
140};
141
142struct ClAddKernel : public ClKernel
143{
144public:
145 Complexity complexity() const override
146 {
147 return Complexity::Simple;
148 }
149 ClAddKernel() = default;
150 ~ClAddKernel() override = default;
151 ClAddKernel(const ClKernelGraph *graph, Id id, const ClKernelConfig &config, const ClEltwiseAddKernelDescriptor &desc, const ITensorDescPack<ClKernelTensor> tensors)
152 : ClKernel{ graph, id, config, tensors }, desc{ desc }
153 {
154 }
155 static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst);
156 bool operator==(const ClKernel &other) const override;
157 Status generate(ClKernelBlueprint &bp) const override;
158
159 ClEltwiseAddKernelDescriptor desc{};
160};
161
162struct ClKernelGraph
163{
164public:
165 using Id = DependencyGraph::Id;
166 using KernelMap = std::map<Id, utils::memory::deep_unique_ptr<ClKernel>>;
167 using KernelTensorMap = std::map<Id, utils::memory::deep_unique_ptr<ClKernelTensor>>;
168
169 ClKernelGraph() = default;
170 ~ClKernelGraph() = default;
171
172 friend bool operator==(const ClKernelGraph &graph0, const ClKernelGraph &graph1)
173 {
174 return graph0.graph == graph1.graph && graph0.kernels == graph1.kernels && graph0.tensors == graph1.tensors;
175 }
176
177 Status add_kernel_tensor(ITensorInfo *desc, MemoryType memory_type, const AuxMemoryInfo &memory_info, Id &tensor_id, Id merge_point = DependencyGraph::empty_id())
178 {
179 tensor_id = graph.add_tensor(merge_point);
180 if(tensors.find(tensor_id) == tensors.end())
181 {
182 tensors[tensor_id] = utils::memory::make_deep_unique<ClKernelTensor, ClKernelTensor>(tensor_id, desc, memory_type, memory_info);
183 }
184 return Status{};
185 }
186
187 template <typename ContentT, typename KernelDescT>
188 Status add_kernel(const ClKernelConfig &config, const KernelDescT &desc, const ITensorDescPack<ClKernelTensor> &tensors, Id &kernel_id)
189 {
190 const auto src_tensors = tensors.get_const_src_tensors();
191 const auto dst_tensors = tensors.get_const_dst_tensors();
192 std::vector<Id> src_tensor_ids{};
193 std::vector<Id> dst_tensor_ids{};
194 for(const auto &t : src_tensors)
195 {
196 src_tensor_ids.push_back(t->id);
197 }
198 for(const auto &t : dst_tensors)
199 {
200 dst_tensor_ids.push_back(t->id);
201 }
202 kernel_id = graph.add_operator(src_tensor_ids, dst_tensor_ids).second;
203 auto k = utils::memory::make_deep_unique<ClKernel, ContentT>(this, kernel_id, config, desc, tensors);
204 kernels[kernel_id] = std::move(k);
205 return Status{};
206 }
207
208 ClKernel *get_kernel(Id id)
209 {
210 return kernels.at(id).get();
211 }
212 const ClKernel *get_kernel(Id id) const
213 {
214 return kernels.at(id).get();
215 }
216
217 ClKernelTensor *get_tensor(Id id)
218 {
219 return tensors.at(id).get();
220 }
221 const ClKernelTensor *get_tensor(Id id) const
222 {
223 return tensors.at(id).get();
224 }
225
226 DependencyGraph graph{};
227 KernelMap kernels{};
228 KernelTensorMap tensors{};
229};
230using Id = DependencyGraph::Id;
231
232std::vector<const ClKernel *> traverse(const ClKernelGraph &graph);
233std::vector<ClKernel *> traverse(ClKernelGraph &graph);
234
235} // namespace dynamic_fusion
236} // namespace experimental
237} // namespace arm_compute
SiCong Li4e9f5682022-05-10 10:15:59 +0100238#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_CLKERNELGRAPH_H
239#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */