blob: 17ae49740b76b09977ccdc51fe4e68a208a473e7 [file] [log] [blame]
Georgios Pinitas0c29cd32017-10-18 17:29:27 +01001/*
Michele Di Giorgiodde9ec92018-02-13 15:24:04 +00002 * Copyright (c) 2017-2018 ARM Limited.
Georgios Pinitas0c29cd32017-10-18 17:29:27 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_GRAPH_NODE_CONTEXT_H__
25#define __ARM_COMPUTE_GRAPH_NODE_CONTEXT_H__
26
#include "arm_compute/core/Error.h"
#include "arm_compute/graph/NodeParameter.h"
#include "arm_compute/graph/Types.h"
#include "support/ToolchainSupport.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

35
36namespace arm_compute
37{
38namespace graph
39{
40/** Node Context class
41 *
42 * Node context class is used to hold all the parameters required by a node to execute
43 */
44class NodeContext
45{
46public:
47 /** Default Constructor
Georgios Pinitas0c29cd32017-10-18 17:29:27 +010048 *
49 * @param[in] operation Name of the operation
50 */
Georgios Pinitas407c3e62017-10-25 18:26:46 +010051 NodeContext(OperationType operation)
Georgios Pinitas0c29cd32017-10-18 17:29:27 +010052 : _operation(operation), _target(TargetHint::DONT_CARE), _inputs(), _outputs(), _parameters() {};
53 /** Sets the execution target of the node
54 *
55 * @param[in] target Execution target of the node
56 */
57 void set_target(TargetHint target);
58 /** Adds an input tensor to the context
59 *
60 * @param[in] input Input to add
61 */
62 void add_input(arm_compute::ITensor *input);
Michele Di Giorgiodde9ec92018-02-13 15:24:04 +000063 /** Adds an output to the context
Georgios Pinitas0c29cd32017-10-18 17:29:27 +010064 *
65 * @param[in] output Output to add
66 */
67 void add_output(arm_compute::ITensor *output);
68 /** Adds a parameter to the context
69 *
70 * @param[in] name Parameter name
71 * @param[in] parameter Parameter to add
72 */
73 template <typename T>
74 void add_parameter(std::string name, T parameter);
75 /** Returns the operation of this node.
76 *
Georgios Pinitas407c3e62017-10-25 18:26:46 +010077 * @return The operation type
Georgios Pinitas0c29cd32017-10-18 17:29:27 +010078 */
Georgios Pinitas407c3e62017-10-25 18:26:46 +010079 OperationType operation() const;
Georgios Pinitas0c29cd32017-10-18 17:29:27 +010080 /** Returns the execution target of this node
81 *
82 * @return The execution target
83 */
84 TargetHint target() const;
85 /** Returns input tensor of a given index
86 *
87 * @param[in] idx Index of the input tensor
88 *
89 * @return A pointer the requested input tensor else nullptr
90 */
91 arm_compute::ITensor *input(size_t idx) const;
92 /** Returns output tensor of a given index
93 *
94 * @param[in] idx Index of the output tensor
95 *
96 * @return A pointer the requested output tensor else nullptr
97 */
98 arm_compute::ITensor *output(size_t idx) const;
99 /** Returns the parameter with the given name
100 *
101 * @param[in] name Parameter name
102 *
103 * @return The requested parameter else an empty object
104 */
105 template <typename T>
106 T parameter(std::string name) const;
107 /** Returns number of inputs
108 *
109 * @return Number of inputs
110 */
111 size_t num_inputs() const;
112 /** Returns number of output
113 *
114 * @return Number of outputs
115 */
116 size_t num_outputs() const;
117
118private:
Georgios Pinitas407c3e62017-10-25 18:26:46 +0100119 OperationType _operation;
Georgios Pinitas0c29cd32017-10-18 17:29:27 +0100120 TargetHint _target;
121 std::vector<arm_compute::ITensor *> _inputs;
122 std::vector<arm_compute::ITensor *> _outputs;
123 std::map<std::string, std::unique_ptr<NodeParameterBase>> _parameters;
124};
125
126template <typename T>
127inline void NodeContext::add_parameter(std::string name, T parameter)
128{
129 ARM_COMPUTE_ERROR_ON_MSG(_parameters.find(name) != _parameters.end(), "Parameter already exists!");
130 _parameters[name] = support::cpp14::make_unique<NodeParameter<T>>(name, parameter);
131}
132
133template <typename T>
134inline T NodeContext::parameter(std::string name) const
135{
136 auto it = _parameters.find(name);
137 ARM_COMPUTE_ERROR_ON(it == _parameters.end());
138 return static_cast<NodeParameter<T> *>(it->second.get())->value();
139}
140} // namespace graph
141} // namespace arm_compute
142#endif /* __ARM_COMPUTE_GRAPH_NODE_CONTEXT_H__ */