blob: f118d7d8515a9026ec3a5b6f0e7f0dc7df94fd83 [file] [log] [blame]
SiCong Lif44bbc52022-08-29 18:25:51 +01001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef SRC_DYNAMIC_FUSION_SKETCH_ARGUMENTPACK
25#define SRC_DYNAMIC_FUSION_SKETCH_ARGUMENTPACK
26
27#include "arm_compute/core/experimental/Types.h"
28#include <unordered_map>
29#include <vector>
30
31namespace arm_compute
32{
33namespace experimental
34{
35namespace dynamic_fusion
36{
37/** This is a generic class that packs the arguments of an operator. For now, it is only used for tensor-related types
38 * Examples of "tensor-related types": @ref ITensorInfo, @ref ITensor, @ref ICLTensor
39 *
40 * The argument id is the position of the argument within the pack, and is represented by @ref TensorType
41 *
42 * @tparam T Tensor-related type
43 */
44template <typename T>
45class ArgumentPack
46{
47public:
48 /** @ref TensorType encodes the position of a tensor argument within the pack */
49 using Id = TensorType;
50 /** A single argument element within the pack
51 * It contains either a const pointer or a non-const pointer to the Tensor-related type T, but never at the same time
52 */
53 struct PackElement
54 {
55 PackElement() = default;
56 PackElement(const PackElement &elem) = default;
57 PackElement &operator=(const PackElement &elem) = default;
58 PackElement(PackElement &&elem) = default;
59 PackElement &operator=(PackElement &&elem) = default;
60 PackElement(Id id, T *tensor)
61 : id(id), tensor(tensor), ctensor(nullptr)
62 {
63 }
64 PackElement(Id id, const T *ctensor)
65 : id(id), tensor(nullptr), ctensor(ctensor)
66 {
67 }
68
69 Id id{ ACL_UNKNOWN }; /**< Argument id within the pack */
70 T *tensor{ nullptr }; /**< Non-const pointer to tensor-related object */
71 const T *ctensor
72 {
73 nullptr
74 }; /**< Const pointer to tensor-related object */
75 };
76
77public:
78 /** Default constructor */
79 ArgumentPack() = default;
80 /** Destructor */
81 ~ArgumentPack() = default;
82 /** Allow instances of this class to be copy constructed */
83 ArgumentPack<T>(const ArgumentPack<T> &other) = default;
84 /** Allow instances of this class to be copied */
85 ArgumentPack<T> &operator=(const ArgumentPack<T> &other) = default;
86 /** Allow instances of this class to be move constructed */
87 ArgumentPack<T>(ArgumentPack<T> &&other) = default;
88 /** Allow instances of this class to be moved */
89 ArgumentPack<T> &operator=(ArgumentPack<T> &&other) = default;
90 /** Initializer list Constructor */
91 ArgumentPack(const std::initializer_list<PackElement> &l)
92 : _pack{}
93 {
94 for(const auto &e : l)
95 {
96 _pack[e.id] = e;
97 }
98 }
99 /** Add tensor to the pack
100 *
101 * @param[in] id ID of the tensor to add
102 * @param[in] tensor Tensor to add
103 */
104 void add_tensor(Id id, T *tensor)
105 {
106 _pack[id] = PackElement(id, tensor);
107 }
108 /** Add const tensor to the pack
109 *
110 * @param[in] id ID of the tensor to add
111 * @param[in] tensor Tensor to add
112 */
113 void add_const_tensor(Id id, const T *tensor)
114 {
115 _pack[id] = PackElement(id, tensor);
116 }
117 /** Get tensor of a given id from the pack
118 *
119 * @param[in] id ID of tensor to extract
120 *
121 * @return The pointer to the tensor if exist and is non-const else nullptr
122 */
123 T *get_tensor(Id id)
124 {
125 auto it = _pack.find(id);
126 return it != _pack.end() ? it->second.tensor : nullptr;
127 }
128 /** Get constant tensor of a given id
129 *
130 * @param[in] id ID of tensor to extract
131 *
132 * @return The pointer to the tensor (const or not) if exist else nullptr
133 */
134 const T *get_const_tensor(Id id) const
135 {
136 auto it = _pack.find(id);
137 if(it != _pack.end())
138 {
139 return it->second.ctensor != nullptr ? it->second.ctensor : it->second.tensor;
140 }
141 return nullptr;
142 }
143 /** Remove the tensor stored with the given id
144 *
145 * @param[in] id ID of tensor to remove
146 */
147 void remove_tensor(Id id)
148 {
149 _pack.erase(id);
150 }
151 /** Pack size accessor
152 *
153 * @return Number of tensors registered to the pack
154 */
155 size_t size() const
156 {
157 return _pack.size();
158 }
159 /** Checks if pack is empty
160 *
161 * @return True if empty else false
162 */
163 bool empty() const
164 {
165 return _pack.empty();
166 }
167 /** Get the ACL_SRC_* tensors
168 *
169 * @return std::vector<T *>
170 */
171 std::vector<T *> get_src_tensors()
172 {
173 std::vector<T *> src_tensors{};
174 for(int id = static_cast<int>(TensorType::ACL_SRC); id <= static_cast<int>(TensorType::ACL_SRC_END); ++id)
175 {
176 auto tensor = get_tensor(static_cast<TensorType>(id));
177 if(tensor != nullptr)
178 {
179 src_tensors.push_back(tensor);
180 }
181 }
182 return src_tensors;
183 }
184 /** Get the const ACL_SRC_* tensors
185 *
186 * @return std::vector<const T *>
187 */
188 std::vector<const T *> get_const_src_tensors() const
189 {
190 std::vector<const T *> src_tensors{};
191 for(int id = static_cast<int>(TensorType::ACL_SRC); id <= static_cast<int>(TensorType::ACL_SRC_END); ++id)
192 {
193 auto tensor = get_const_tensor(static_cast<TensorType>(id));
194 if(tensor != nullptr)
195 {
196 src_tensors.push_back(tensor);
197 }
198 }
199 return src_tensors;
200 }
201 /** Get the ACL_DST_* tensors
202 *
203 * @return std::vector<T *>
204 */
205 std::vector<T *> get_dst_tensors()
206 {
207 std::vector<T *> dst_tensors{};
208 for(int id = static_cast<int>(TensorType::ACL_DST); id <= static_cast<int>(TensorType::ACL_DST_END); ++id)
209 {
210 auto tensor = get_tensor(static_cast<TensorType>(id));
211 if(tensor != nullptr)
212 {
213 dst_tensors.push_back(tensor);
214 }
215 }
216 return dst_tensors;
217 }
218 /** Get the const ACL_DST_* tensors
219 *
220 * @return std::vector<const T *>
221 */
222 std::vector<const T *> get_const_dst_tensors() const
223 {
224 std::vector<const T *> dst_tensors{};
225 for(int id = static_cast<int>(TensorType::ACL_DST); id <= static_cast<int>(TensorType::ACL_DST_END); ++id)
226 {
227 auto tensor = get_const_tensor(static_cast<TensorType>(id));
228 if(tensor != nullptr)
229 {
230 dst_tensors.push_back(tensor);
231 }
232 }
233 return dst_tensors;
234 }
235
236private:
237 std::unordered_map<int, PackElement> _pack{}; /**< Container with the packed tensors */
238};
239} // namespace dynamic_fusion
240} // namespace experimental
241} // namespace arm_compute
242#endif /* SRC_DYNAMIC_FUSION_SKETCH_ARGUMENTPACK */