blob: 65219998cb2e842770a199659962faf33f6c1da7 [file] [log] [blame]
SiCong Li7061eb22021-01-08 15:16:02 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/runtime/CL/mlgo/HeuristicTree.h"
25#include "arm_compute/core/Log.h"
26
27#include <algorithm>
28#include <deque>
29#include <set>
30namespace arm_compute
31{
32namespace mlgo
33{
34namespace
35{
36bool evaluate(GEMMShape shape, Condition cond)
37{
38 // PRE: all features and ConditionalOps are valid
39 constexpr float eps = 0.0001f;
40 // Calculate all secondary features
41 std::vector<std::pair<std::string, float>> cond_values
42 {
43 { "m", static_cast<float>(shape.m) },
44 { "n", static_cast<float>(shape.n) },
45 { "k", static_cast<float>(shape.k) },
46 { "b", static_cast<float>(shape.b) },
47 { "r_mn", static_cast<float>(shape.m) / shape.n },
48 { "r_mk", static_cast<float>(shape.m) / shape.k },
49 { "r_nk", static_cast<float>(shape.n) / shape.k },
50 { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
51 { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
52 };
53 auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
54 [&cond](decltype(*cond_values.begin()) it)
55 {
56 return it.first == cond.feature;
57 });
58
59 ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
60 const float cond_value = cond_value_pair_it->second;
61 switch(cond.op)
62 {
63 case ConditionalOp::LT:
64 {
65 return cond_value < cond.threshold;
66 }
67 case ConditionalOp::LE:
68 {
69 return cond_value <= cond.threshold;
70 }
71 case ConditionalOp::GT:
72 {
73 return cond_value > cond.threshold;
74 }
75 case ConditionalOp::GE:
76 {
77 return cond_value >= cond.threshold;
78 }
79 case ConditionalOp::EQ:
80 default:
81 {
82 return std::abs(cond_value - cond.threshold) < eps;
83 }
84 }
85}
86
87} // namespace
88
// Namespace-scope definitions of HeuristicTree's static constexpr data members.
// Before C++17, a static constexpr data member that is ODR-used (e.g. bound to a reference) needs a
// definition like these in exactly one translation unit.
// NOTE(review): the in-class declarations and their values live in the header, which is not visible here.
constexpr size_t HeuristicTree::_max_num_nodes;
constexpr size_t HeuristicTree::_max_query_depth;
constexpr HeuristicTree::NodeID HeuristicTree::_root;
92
93HeuristicTree::HeuristicTree()
94 : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
95{
96}
97
98HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
99 : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
100{
101}
102
103template <typename T>
104std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
105{
106 // Root ID = 0;
107 auto cur_node = _tree.at(_root).get();
108 size_t depth = 0;
109 while(cur_node->type() != NodeType::Leaf)
110 {
111 if(depth > _max_query_depth)
112 {
113 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
114 return std::make_pair(false, T{});
115 }
116 ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
117 auto br_node = dynamic_cast<BranchNode *>(cur_node);
118 if(evaluate(shape, br_node->condition))
119 {
120 cur_node = _tree.at(br_node->true_node).get();
121 }
122 else
123 {
124 cur_node = _tree.at(br_node->false_node).get();
125 }
126 ++depth;
127 }
128 ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
129 auto l_node = dynamic_cast<LeafNode<T> *>(cur_node);
130 return std::make_pair(true, l_node->value);
131}
132
133template <typename T>
134bool HeuristicTree::add_leaf(NodeID id, T val)
135{
136 if(_tree.size() >= _max_num_nodes)
137 {
138 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
139 return false;
140 }
141 if(_tree.find(id) != _tree.end())
142 {
143 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
144 return false;
145 }
146 _tree[id] = std::make_unique<LeafNode<T>>(id, val);
147 return true;
148}
149
150bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
151{
152 if(_tree.size() >= _max_num_nodes)
153 {
154 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
155 return false;
156 }
157
158 const std::set<std::string> supported_features =
159 {
160 "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
161 };
162 const auto orig_feature = cond.feature;
163 std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
164 {
165 return std::tolower(c);
166 });
167 if(supported_features.find(cond.feature) == supported_features.end())
168 {
169 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
170 return false;
171 }
172
173 if(_tree.find(id) != _tree.end())
174 {
175 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
176 return false;
177 }
178 _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
179 return true;
180}
181
182bool HeuristicTree::check_if_structurally_correct() const
183{
184 std::set<NodeID> visited;
185 std::deque<NodeID> to_visit{ _root };
186
187 while(!to_visit.empty())
188 {
189 auto id = to_visit.front();
190 to_visit.pop_front();
191 if(_tree.find(id) == _tree.end())
192 {
193 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
194 return false;
195 }
196 auto not_seen_before = visited.insert(id);
197 if(!not_seen_before.second)
198 {
199 ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
200 return false;
201 }
202 auto cur_node = _tree.at(id).get();
203 if(cur_node->type() == NodeType::Branch)
204 {
205 auto br_node = dynamic_cast<BranchNode *>(cur_node);
206 to_visit.push_back(br_node->true_node);
207 to_visit.push_back(br_node->false_node);
208 }
209 }
210 if(visited.size() != _tree.size())
211 {
212 ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
213 return false;
214 }
215 return true;
216}
217
218bool HeuristicTree::check()
219{
220 if(_tree.empty())
221 {
222 ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
223 return false;
224 }
225 if(_tree.find(_root) == _tree.end())
226 {
227 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
228 return false;
229 }
230 return check_if_structurally_correct();
231}
232
// The member templates query() and add_leaf() are defined in this source file, so the explicit
// instantiations below are what make them available to other translation units for each listed value type.

/** Explicit template instantiation of query() for GEMMType leaf values @relates HeuristicTree */
template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
/** Explicit template instantiation of query() for GEMMConfigNative leaf values @relates HeuristicTree */
template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
/** Explicit template instantiation of query() for GEMMConfigReshapedOnlyRHS leaf values @relates HeuristicTree */
template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
/** Explicit template instantiation of query() for GEMMConfigReshaped leaf values @relates HeuristicTree */
template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;

/** Explicit template instantiation of add_leaf() for GEMMType leaf values @relates HeuristicTree */
template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
/** Explicit template instantiation of add_leaf() for GEMMConfigNative leaf values @relates HeuristicTree */
template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
/** Explicit template instantiation of add_leaf() for GEMMConfigReshapedOnlyRHS leaf values @relates HeuristicTree */
template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val);
/** Explicit template instantiation of add_leaf() for GEMMConfigReshaped leaf values @relates HeuristicTree */
template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
250
251} // namespace mlgo
252
253} // namespace arm_compute