blob: 7a20a565b8d15a1996a60e7f9b2dfa387abd4155 [file] [log] [blame]
Georgios Pinitase2c82fe2017-10-02 18:51:47 +01001/*
Giorgio Arenaa66eaa22017-12-21 19:50:06 +00002 * Copyright (c) 2017-2018 ARM Limited.
Georgios Pinitase2c82fe2017-10-02 18:51:47 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/graph/nodes/BranchLayer.h"
25
Michalis Spyroued194b12017-10-31 15:04:34 +000026#include "arm_compute/graph/Error.h"
Georgios Pinitase2c82fe2017-10-02 18:51:47 +010027#include "arm_compute/graph/Graph.h"
28#include "arm_compute/graph/SubGraph.h"
29#include "arm_compute/graph/Tensor.h"
30#include "arm_compute/runtime/IFunction.h"
31#include "support/ToolchainSupport.h"
32#include "utils/TypePrinter.h"
33
34#include <memory>
35#include <tuple>
36#include <vector>
37
38using namespace arm_compute::graph;
39
Georgios Pinitase2c82fe2017-10-02 18:51:47 +010040/** Branch function */
41class BranchFunction final : public arm_compute::IFunction
42{
43public:
44 /** Default Constructor */
45 BranchFunction()
46 : _graphs()
47 {
48 }
49 /** Registers graph to be executed by the branch function
50 *
51 * @param[in] graph Graph to register
52 */
53 void register_graph(std::unique_ptr<Graph> graph)
54 {
55 _graphs.push_back(std::move(graph));
56 }
57 // Inherited methods overriden:
58 void run() override
59 {
60 for(auto &g : _graphs)
61 {
62 ARM_COMPUTE_ERROR_ON(g.get() == nullptr);
63 g->run();
64 }
65 }
66
67private:
68 std::vector<std::unique_ptr<Graph>> _graphs;
69};
70
71std::unique_ptr<arm_compute::IFunction> BranchLayer::instantiate_node(GraphContext &ctx, ITensorObject *input, ITensorObject *output)
72{
73 ARM_COMPUTE_ERROR_ON(_branch_merge_method != BranchMergeMethod::DEPTH_CONCATENATE);
74 ARM_COMPUTE_UNUSED(_branch_merge_method);
Michalis Spyroued194b12017-10-31 15:04:34 +000075 ARM_COMPUTE_ERROR_ON_UNALLOCATED_TENSOR_OBJECT(input, output);
Georgios Pinitase2c82fe2017-10-02 18:51:47 +010076
77 // Create branch function
78 auto func = arm_compute::support::cpp14::make_unique<BranchFunction>();
79
Georgios Pinitas652bde52018-01-10 15:33:28 +000080 // Track output depth
81 int depth = 0;
Georgios Pinitase2c82fe2017-10-02 18:51:47 +010082
83 // Constuct all sub-graphs given the input/output
84 for(auto &sg : _sub_graphs)
85 {
86 ARM_COMPUTE_ERROR_ON(sg.get() == nullptr);
87
88 // IO buffers
89 std::unique_ptr<ITensorObject> in;
90 std::unique_ptr<ITensorObject> out;
91 SubTensor *out_sub_tensor = nullptr;
92
93 // Create input sub-tensor
94 if(!sg->has_input())
95 {
96 ARM_COMPUTE_ERROR_ON(dynamic_cast<Tensor *>(input) == nullptr);
97 in = arm_compute::support::cpp14::make_unique<SubTensor>(*dynamic_cast<Tensor *>(input),
98 input->tensor()->info()->tensor_shape(),
99 Coordinates());
100 }
101
102 // Create output sub-tensor
103 if(!sg->has_output())
104 {
Georgios Pinitas652bde52018-01-10 15:33:28 +0000105 ARM_COMPUTE_ERROR_ON((dynamic_cast<Tensor *>(output) == nullptr) && (dynamic_cast<SubTensor *>(output) == nullptr));
106
107 out = arm_compute::support::cpp14::make_unique<SubTensor>(output->tensor(),
108 TensorShape(),
109 Coordinates(0, 0, depth),
110 output->target(),
111 true);
Georgios Pinitase2c82fe2017-10-02 18:51:47 +0100112 out_sub_tensor = dynamic_cast<SubTensor *>(out.get());
113 }
114
115 // Construct sub_graph
Georgios Pinitas1250a5a2018-01-02 13:27:37 +0000116 auto g = sg->construct(ctx, std::move(in), std::move(out));
Georgios Pinitase2c82fe2017-10-02 18:51:47 +0100117
118 // Register graph to function
119 func->register_graph(std::move(g));
120
121 // Update and track depth
122 if(out_sub_tensor != nullptr)
123 {
124 ARM_COMPUTE_ERROR_ON(out_sub_tensor->tensor() == nullptr);
125 depth += out_sub_tensor->tensor()->info()->tensor_shape()[2];
Georgios Pinitase2c82fe2017-10-02 18:51:47 +0100126 }
127 }
128
Georgios Pinitase2c82fe2017-10-02 18:51:47 +0100129 return std::move(func);
130}