blob: d5c7de22155c2ec40b528e92fc7c22d56b906abe [file] [log] [blame]
SiCong Li7061eb22021-01-08 15:16:02 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
SiCong Li70858d82021-02-05 09:19:51 +000024#ifndef SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
25#define SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
SiCong Li7061eb22021-01-08 15:16:02 +000026
#include "arm_compute/core/Types.h"
#include "src/runtime/CL/mlgo/Common.h"

#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
33#include <utility>
34
35namespace arm_compute
36{
37namespace mlgo
38{
/** Comparison operators usable in a branch condition.
 *
 * Values are spelled out explicitly; they match the implicit 0-based numbering.
 */
enum class ConditionalOp
{
    EQ = 0, ///< Equal
    LT = 1, ///< Less than
    LE = 2, ///< Less than or equal to
    GT = 3, ///< Greater than
    GE = 4  ///< Greater than or equal to
};
48
49/** A branch condition expression evaluating: feature op threshold */
50struct Condition
51{
52 std::string feature; /**< Feature name */
53 ConditionalOp op; /**< Condtional op */
54 float threshold; /**< Threshold value */
55};
56
/** GEMM shape used to query a heuristic tree.
 *
 * All dimensions describe the matrices in their non-transposed layout.
 */
struct GEMMShape
{
    /** Number of rows of the lhs matrix */
    unsigned int m;
    /** Number of columns of the rhs matrix */
    unsigned int n;
    /** Number of rows of the rhs matrix */
    unsigned int k;
    /** Batch size */
    unsigned int b;
};
65
66/** A binary decision tree based heuristic */
67class HeuristicTree
68{
69public:
70 using NodeID = size_t;
71 using TreeID = size_t;
72 using Index = std::tuple<HeuristicType, std::string, DataType>;
73 enum class NodeType
74 {
75 Branch,
76 Leaf
77 };
78 struct Node
79 {
80 virtual NodeType type() const = 0;
81 virtual ~Node() = default;
82 };
83
84 struct BranchNode : public Node
85 {
86 BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
87 : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node }
88 {
89 }
90 NodeType type() const override
91 {
92 return NodeType::Branch;
93 }
94 NodeID id;
95 Condition condition;
96 NodeID true_node;
97 NodeID false_node;
98 };
99
100 template <typename T>
101 struct LeafNode : public Node
102 {
103 LeafNode(NodeID id, T val)
104 : id{ id }, value{ val }
105 {
106 }
107 NodeType type() const override
108 {
109 return NodeType::Leaf;
110 }
111 NodeID id;
112 T value;
113 };
114
115public:
116 /** Constructor */
117 HeuristicTree();
118 /** Constructor */
119 HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type);
120 // Since the HeuristicTree is a handle that owns the the nodes, it is move-only
121 /** Prevent copy construction */
122 HeuristicTree(const HeuristicTree &) = delete;
123 /** Prevent copy assignment */
124 HeuristicTree &operator=(const HeuristicTree &) = delete;
125 /** Move constructor */
126 HeuristicTree(HeuristicTree &&other) noexcept = default;
127 /** Move assignment */
Manuel Bottini8e213312021-02-08 17:07:04 +0000128 HeuristicTree &operator=(HeuristicTree &&other) = default;
SiCong Li7061eb22021-01-08 15:16:02 +0000129
130 /** Query a leaf value given a gemm shape
131 *
132 * @tparam T Leaf value type
133 * @param shape A @ref GEMMShape for the query
134 * @return std::pair<bool, T> Outcome contains bool, signalling if the query succeeded or not
135 */
136 template <typename T>
137 std::pair<bool, T> query(GEMMShape shape) const;
138
139 /** Add a leaf node
140 *
141 * @tparam T Leaf value type
142 * @param id Leaf node ID
143 * @param leaf_value Leaf node value
144 * @return bool If the addition succeeded or not
145 */
146 template <typename T>
147 bool add_leaf(NodeID id, T leaf_value);
148 /** Add a branch node
149 *
150 * @param id Branch node ID
151 * @param cond Branch node @ref Condition
152 * @param true_node True node's ID
153 * @param false_node False node's ID
154 * @return bool If the addition succeeded or not
155 */
156 bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node);
157
158 /** Get tree ID
159 * @return TreeID
160 */
161 TreeID id() const
162 {
163 return _id;
164 }
165
166 /** Get tree index
167 * @return Index
168 */
169 Index index() const
170 {
171 return std::make_tuple(_heuristic_type, _ip_target, _data_type);
172 }
173
174 /** Check if tree is valid
175 * @return bool
176 */
177 bool check();
178
179private:
180 static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query
181 static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree
182 static constexpr NodeID _root{ 0 }; // Root tree ID
183
184private:
185 bool check_if_structurally_correct() const;
186
187private:
188 TreeID _id; /**< Heuristic tree ID */
189 HeuristicType _heuristic_type; /**< Heuristic type */
190 std::string _ip_target; /**< IP target associated with the tree */
191 DataType _data_type; /**< Data type associated with the tree */
192 std::map<NodeID, std::unique_ptr<Node>> _tree; /**< Tree representation */
193};
194} // namespace mlgo
195
196} // namespace arm_compute
197
SiCong Li70858d82021-02-05 09:19:51 +0000198#endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H