blob: f7b706902b091559b55e10319a356ada4c10168d [file] [log] [blame]
SiCong Li7061eb22021-01-08 15:16:02 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/runtime/CL/mlgo/HeuristicTree.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010025
SiCong Li7061eb22021-01-08 15:16:02 +000026#include "arm_compute/core/Log.h"
27
Georgios Pinitasb0cd5d82021-02-22 18:17:43 +000028#include "support/Cast.h"
29
SiCong Li7061eb22021-01-08 15:16:02 +000030#include <algorithm>
31#include <deque>
32#include <set>
33namespace arm_compute
34{
35namespace mlgo
36{
37namespace
38{
39bool evaluate(GEMMShape shape, Condition cond)
40{
41 // PRE: all features and ConditionalOps are valid
42 constexpr float eps = 0.0001f;
43 // Calculate all secondary features
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010044 std::vector<std::pair<std::string, float>> cond_values{
45 {"m", static_cast<float>(shape.m)},
46 {"n", static_cast<float>(shape.n)},
47 {"k", static_cast<float>(shape.k)},
48 {"b", static_cast<float>(shape.b)},
49 {"r_mn", static_cast<float>(shape.m) / shape.n},
50 {"r_mk", static_cast<float>(shape.m) / shape.k},
51 {"r_nk", static_cast<float>(shape.n) / shape.k},
52 {"r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k)},
53 {"workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0}};
54 auto cond_value_pair_it =
55 std::find_if(cond_values.begin(), cond_values.end(),
56 [&cond](decltype(*cond_values.begin()) it) { return it.first == cond.feature; });
SiCong Li7061eb22021-01-08 15:16:02 +000057
58 ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
59 const float cond_value = cond_value_pair_it->second;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010060 switch (cond.op)
SiCong Li7061eb22021-01-08 15:16:02 +000061 {
62 case ConditionalOp::LT:
63 {
64 return cond_value < cond.threshold;
65 }
66 case ConditionalOp::LE:
67 {
68 return cond_value <= cond.threshold;
69 }
70 case ConditionalOp::GT:
71 {
72 return cond_value > cond.threshold;
73 }
74 case ConditionalOp::GE:
75 {
76 return cond_value >= cond.threshold;
77 }
78 case ConditionalOp::EQ:
79 default:
80 {
81 return std::abs(cond_value - cond.threshold) < eps;
82 }
83 }
84}
85
86} // namespace
87
88constexpr size_t HeuristicTree::_max_num_nodes;
89constexpr size_t HeuristicTree::_max_query_depth;
90constexpr HeuristicTree::NodeID HeuristicTree::_root;
91
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010092HeuristicTree::HeuristicTree() : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
SiCong Li7061eb22021-01-08 15:16:02 +000093{
94}
95
96HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010097 : _id{id}, _heuristic_type{h_type}, _ip_target{ip_target}, _data_type{data_type}, _tree{}
SiCong Li7061eb22021-01-08 15:16:02 +000098{
99}
100
101template <typename T>
102std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
103{
104 // Root ID = 0;
105 auto cur_node = _tree.at(_root).get();
106 size_t depth = 0;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100107 while (cur_node->type() != NodeType::Leaf)
SiCong Li7061eb22021-01-08 15:16:02 +0000108 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100109 if (depth > _max_query_depth)
SiCong Li7061eb22021-01-08 15:16:02 +0000110 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100111 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?",
112 _max_query_depth);
SiCong Li7061eb22021-01-08 15:16:02 +0000113 return std::make_pair(false, T{});
114 }
115 ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
Georgios Pinitasb0cd5d82021-02-22 18:17:43 +0000116 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100117 if (evaluate(shape, br_node->condition))
SiCong Li7061eb22021-01-08 15:16:02 +0000118 {
119 cur_node = _tree.at(br_node->true_node).get();
120 }
121 else
122 {
123 cur_node = _tree.at(br_node->false_node).get();
124 }
125 ++depth;
126 }
127 ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
Georgios Pinitasb0cd5d82021-02-22 18:17:43 +0000128 auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
SiCong Li7061eb22021-01-08 15:16:02 +0000129 return std::make_pair(true, l_node->value);
130}
131
132template <typename T>
133bool HeuristicTree::add_leaf(NodeID id, T val)
134{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100135 if (_tree.size() >= _max_num_nodes)
SiCong Li7061eb22021-01-08 15:16:02 +0000136 {
137 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
138 return false;
139 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100140 if (_tree.find(id) != _tree.end())
SiCong Li7061eb22021-01-08 15:16:02 +0000141 {
142 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
143 return false;
144 }
145 _tree[id] = std::make_unique<LeafNode<T>>(id, val);
146 return true;
147}
148
149bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
150{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100151 if (_tree.size() >= _max_num_nodes)
SiCong Li7061eb22021-01-08 15:16:02 +0000152 {
153 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
154 return false;
155 }
156
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100157 const std::set<std::string> supported_features = {"m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"};
158 const auto orig_feature = cond.feature;
159 std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(),
160 [](char c) { return std::tolower(c); });
161 if (supported_features.find(cond.feature) == supported_features.end())
SiCong Li7061eb22021-01-08 15:16:02 +0000162 {
163 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
164 return false;
165 }
166
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100167 if (_tree.find(id) != _tree.end())
SiCong Li7061eb22021-01-08 15:16:02 +0000168 {
169 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
170 return false;
171 }
172 _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
173 return true;
174}
175
176bool HeuristicTree::check_if_structurally_correct() const
177{
178 std::set<NodeID> visited;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100179 std::deque<NodeID> to_visit{_root};
SiCong Li7061eb22021-01-08 15:16:02 +0000180
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100181 while (!to_visit.empty())
SiCong Li7061eb22021-01-08 15:16:02 +0000182 {
183 auto id = to_visit.front();
184 to_visit.pop_front();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100185 if (_tree.find(id) == _tree.end())
SiCong Li7061eb22021-01-08 15:16:02 +0000186 {
187 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
188 return false;
189 }
190 auto not_seen_before = visited.insert(id);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100191 if (!not_seen_before.second)
SiCong Li7061eb22021-01-08 15:16:02 +0000192 {
193 ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
194 return false;
195 }
196 auto cur_node = _tree.at(id).get();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100197 if (cur_node->type() == NodeType::Branch)
SiCong Li7061eb22021-01-08 15:16:02 +0000198 {
Georgios Pinitasb0cd5d82021-02-22 18:17:43 +0000199 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
SiCong Li7061eb22021-01-08 15:16:02 +0000200 to_visit.push_back(br_node->true_node);
201 to_visit.push_back(br_node->false_node);
202 }
203 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100204 if (visited.size() != _tree.size())
SiCong Li7061eb22021-01-08 15:16:02 +0000205 {
206 ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
207 return false;
208 }
209 return true;
210}
211
212bool HeuristicTree::check()
213{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100214 if (_tree.empty())
SiCong Li7061eb22021-01-08 15:16:02 +0000215 {
216 ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
217 return false;
218 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100219 if (_tree.find(_root) == _tree.end())
SiCong Li7061eb22021-01-08 15:16:02 +0000220 {
221 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
222 return false;
223 }
224 return check_if_structurally_correct();
225}
226
227/** Explicit template instantiation @relates HeuristicTree */
228template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
229/** Explicit template instantiation @relates HeuristicTree */
230template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
231/** Explicit template instantiation @relates HeuristicTree */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100232template std::pair<bool, GEMMConfigReshapedOnlyRHS>
233HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
SiCong Li7061eb22021-01-08 15:16:02 +0000234/** Explicit template instantiation @relates HeuristicTree */
235template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
236
237/** Explicit template instantiation @relates HeuristicTree */
238template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
239/** Explicit template instantiation @relates HeuristicTree */
240template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
241/** Explicit template instantiation @relates HeuristicTree */
242template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val);
243/** Explicit template instantiation @relates HeuristicTree */
244template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
245
246} // namespace mlgo
247
248} // namespace arm_compute