blob: 5d9c36e09864beabcc4070d3ff9a55c92a325a47 [file] [log] [blame]
/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_GRAPH_INODE_H__
25#define __ARM_COMPUTE_GRAPH_INODE_H__
26
#include "arm_compute/core/Error.h"
#include "arm_compute/graph/TensorDescriptor.h"
#include "arm_compute/graph/Types.h"

#include <set>
#include <string>
#include <vector>
Anthony Barbier2a07e182017-08-04 18:20:27 +010032
33namespace arm_compute
34{
35namespace graph
36{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010037// Forward declarations
38class Graph;
39class Edge;
40class INodeVisitor;
41class Tensor;
42
Anthony Barbier2a07e182017-08-04 18:20:27 +010043/** Node interface */
44class INode
45{
46public:
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010047 /** Constructor */
48 INode();
49 /** Destructor **/
Anthony Barbier2a07e182017-08-04 18:20:27 +010050 virtual ~INode() = default;
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010051 /** Prevent instances of this class from being copied (As this class contains pointers) */
52 INode(const INode &) = delete;
53 /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
54 INode &operator=(const INode &) = delete;
55 /** Allow instances of this class to be moved */
56 INode(INode &&) = default;
57 /** Allow instances of this class to be move assigned */
58 INode &operator=(INode &&) = default;
59 /** Validate node
Anthony Barbier2a07e182017-08-04 18:20:27 +010060 *
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010061 * @return Status containing any errors
Anthony Barbier2a07e182017-08-04 18:20:27 +010062 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010063 virtual Status validate() = 0;
64 /** Returns node's type
Anthony Barbier2a07e182017-08-04 18:20:27 +010065 *
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010066 * @return Node's type
Anthony Barbier2a07e182017-08-04 18:20:27 +010067 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010068 virtual NodeType type() const = 0;
69 /** Accepts a node visitor
Michele Di Giorgiodde9ec92018-02-13 15:24:04 +000070 *
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010071 * @param[in] v Visitor to accept
Michele Di Giorgiodde9ec92018-02-13 15:24:04 +000072 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010073 virtual void accept(INodeVisitor &v) = 0;
74 /** Forwards descriptor information to outputs if possible
Michele Di Giorgiodde9ec92018-02-13 15:24:04 +000075 *
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010076 * @return True if descriptor information could be forwarded otherwise false
Michele Di Giorgiodde9ec92018-02-13 15:24:04 +000077 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010078 virtual bool forward_descriptors() = 0;
79 /** Calculates output configuration
80 *
81 * @param[in] idx Output index to configure
82 *
83 * @return Output descriptor configuration
84 */
85 virtual TensorDescriptor configure_output(size_t idx) const = 0;
86 /** Returns node's name
87 *
88 * @return Node name
89 */
90 std::string name() const;
91 /** Returns node's ID
92 *
93 * @return Node's ID
94 */
95 NodeID id() const;
96 /** Returns node's Graph
97 *
98 * @return Node's graph
99 */
100 const Graph *graph() const;
101 /** Returns node's Graph
102 *
103 * @return Node's graph
104 */
105 Graph *graph();
106 /** Sets the graph that this node is registered to
107 *
108 * @param[in] g Back reference to graph
109 */
110 void set_graph(Graph *g);
111 /** Sets the node id
112 *
113 * @param[in] id Node id
114 */
115 void set_id(NodeID id);
116 /** Sets common node parameters
117 *
118 * @param[in] common_params Common node parameters to set
119 */
120 void set_common_node_parameters(NodeParams common_params);
121 /** Sets target preference
122 *
123 * @note This is not the target that the graph executor might choose, its just an indication
124 *
125 * @param[in] target Target preference
126 */
127 void set_requested_target(Target target);
128 /** Sets the final execution target
129 *
130 * @note GraphManager might change this target
131 *
132 * @param[in] target Final execution target
133 */
134 void set_assigned_target(Target target);
135 /** Sets the output tensor of at a given index
136 *
137 * @note All edges will get updated
138 *
139 * @param[in] tid Tensor ID
140 * @param[in] idx Output index
141 */
142 void set_output_tensor(TensorID tid, size_t idx);
143 /** Returns inputs of the node
144 *
145 * @return Inputs of the node
146 */
147 const std::vector<TensorID> &inputs() const;
148 /** Returns outputs of the node
149 *
150 * @return Outputs of the node
151 */
152 const std::vector<TensorID> &outputs() const;
153 /** Returns input edge set
154 *
155 * @return Set of input edges
156 */
157 const std::vector<EdgeID> &input_edges() const;
158 /** Returns output edge set
159 *
160 * @return Set of output edges
161 */
162 const std::set<EdgeID> &output_edges() const;
163 /** Returns the tensor ID of a given input of the node
164 *
165 * @note Precondition : idx should be a valid input index
166 *
167 * @param[in] idx Index of the node input
168 *
169 * @return TensorID of the requested input
170 */
171 TensorID input_id(size_t idx) const;
172 /** Returns the tensor ID of a given output of the node
173 *
174 * @note Precondition : idx should be a valid output index
175 *
176 * @param[in] idx Index of the node output
177 *
178 * @return TensorID of the requested output
179 */
180 TensorID output_id(size_t idx) const;
181 /** Returns the tensor of a given input of the node
182 *
183 * @note Precondition : idx should be a valid input index
184 *
185 * @param[in] idx Index of the node input
186 *
187 * @return Tensor of the requested input
188 */
189 Tensor *input(size_t idx) const;
190 /** Returns the tensor of a given output of the node
191 *
192 * @note Precondition : idx should be a valid output index
193 *
194 * @param[in] idx Index of the node output
195 *
196 * @return Tensor of the requested output
197 */
198 Tensor *output(size_t idx) const;
199 /** Returns the edge ID of a given input of the node
200 *
201 * @note Precondition : idx should be a valid input index
202 *
203 * @param[in] idx Index of the node input
204 *
205 * @return EdgeID of the requested input
206 */
207 EdgeID input_edge_id(size_t idx) const;
208 /** Returns the edge of a given input of the node
209 *
210 * @note Precondition : idx should be a valid input index
211 *
212 * @param[in] idx Index of the node input
213 *
214 * @return Edge of the requested input
215 */
216 Edge *input_edge(size_t idx) const;
217 /** Returns number of inputs of the node
218 *
219 * @return Number of inputs
220 */
221 size_t num_inputs() const;
222 /** Returns number of outputs of the node
223 *
224 * @return Number of outputs
225 */
226 size_t num_outputs() const;
227 /** Returns requested target for this node
228 *
229 * @return Requested execution target
230 */
231 Target requested_target() const;
232 /** Returns assigned target for this node
233 *
234 * @return Assigned target of this node
235 */
236 Target assigned_target() const;
Anthony Barbier2a07e182017-08-04 18:20:27 +0100237
Anthony Barbier2a07e182017-08-04 18:20:27 +0100238protected:
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100239 friend class Graph;
Anthony Barbier2a07e182017-08-04 18:20:27 +0100240
241protected:
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100242 Graph *_graph; /**< Backward reference to graph owning the node */
243 NodeID _id; /**< Node ID */
244 NodeParams _common_params; /**< Node common params */
245 std::vector<TensorID> _outputs; /**< Output of the node */
246 std::vector<EdgeID> _input_edges; /**< Inputs edge set */
247 std::set<EdgeID> _output_edges; /**< Output edge set */
248 Target _assigned_target; /**< Assigned target by the Graph executor */
Anthony Barbier2a07e182017-08-04 18:20:27 +0100249};
250} // namespace graph
251} // namespace arm_compute
252#endif /* __ARM_COMPUTE_GRAPH_INODE_H__ */