blob: 35ab2bc861ad3d7c40b48c698d6033afc5b65e86 [file] [log] [blame]
narpra01b9546cf2018-11-20 15:21:28 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +00008
narpra01b9546cf2018-11-20 15:21:28 +00009#include <armnn/INetwork.hpp>
10
11#include <backendsCommon/test/CommonTestUtils.hpp>
12
13#include <boost/test/unit_test.hpp>
14
15#include <vector>
16
17namespace
18{
19
20template<typename armnn::DataType DataType>
21INetworkPtr CreateMergerNetwork(const std::vector<TensorShape>& inputShapes,
22 const TensorShape& outputShape,
23 unsigned int concatAxis,
24 const float qScale = 1.0f,
25 const int32_t qOffset = 0)
26{
27 using namespace armnn;
28 // Builds up the structure of the network.
29 INetworkPtr net(INetwork::Create());
30
31 OriginsDescriptor descriptor;
32
33 descriptor = CreateMergerDescriptorForConcatenation(inputShapes.begin(),
34 inputShapes.end(),
35 concatAxis);
Jim Flynn906f9462019-05-10 13:55:21 +010036 ARMNN_NO_DEPRECATE_WARN_BEGIN
narpra01b9546cf2018-11-20 15:21:28 +000037 IConnectableLayer* merger = net->AddMergerLayer(descriptor, "merger");
Jim Flynn906f9462019-05-10 13:55:21 +010038 ARMNN_NO_DEPRECATE_WARN_END
narpra01b9546cf2018-11-20 15:21:28 +000039
40 for (unsigned int i = 0; i < inputShapes.size(); ++i)
41 {
42 TensorInfo inputTensorInfo(inputShapes[i], DataType, qScale, qOffset);
43 IConnectableLayer* input = net->AddInputLayer(boost::numeric_cast<LayerBindingId>(i));
44 Connect(input, merger, inputTensorInfo, 0, i);
45 }
46
47 TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
48 IConnectableLayer* output = net->AddOutputLayer(0, "output");
49 Connect(merger, output, outputTensorInfo, 0, 0);
50
51 return net;
52}
53
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000054template<armnn::DataType ArmnnType>
narpra01b9546cf2018-11-20 15:21:28 +000055void MergerDim0EndToEnd(const std::vector<BackendId>& backends)
56{
57 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000058 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +000059
60 unsigned int concatAxis = 0;
61 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
62 const TensorShape& outputShape = { 4, 3, 2, 2 };
63
64 // Builds up the structure of the network
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000065 INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +000066
67 BOOST_TEST_CHECKPOINT("create a network");
68
69 // Creates structures for input & output.
70 std::vector<T> inputData{
71 1, 2,
72 3, 4,
73 5, 6,
74 7, 8,
75 9, 10,
76 11, 12,
77 1, 2,
78 3, 4,
79 5, 6,
80 7, 8,
81 9, 10,
82 11, 12
83 };
84
85 std::vector<T> expectedOutput{
86 1, 2,
87 3, 4,
88 5, 6,
89 7, 8,
90 9, 10,
91 11, 12,
92 1, 2,
93 3, 4,
94 5, 6,
95 7, 8,
96 9, 10,
97 11, 12,
98 1, 2,
99 3, 4,
100 5, 6,
101 7, 8,
102 9, 10,
103 11, 12,
104 1, 2,
105 3, 4,
106 5, 6,
107 7, 8,
108 9, 10,
109 11, 12
110 };
111
112 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
113 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
114
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000115 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000116}
117
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000118template<armnn::DataType ArmnnType>
narpra01b9546cf2018-11-20 15:21:28 +0000119void MergerDim1EndToEnd(const std::vector<BackendId>& backends)
120{
121 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000122 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +0000123
124 unsigned int concatAxis = 1;
125 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
126 const TensorShape& outputShape = { 2, 6, 2, 2 };
127
128 // Builds up the structure of the network
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000129 INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000130
131 BOOST_TEST_CHECKPOINT("create a network");
132
133 // Creates structures for input & output.
134 std::vector<T> inputData{
135 1, 2,
136 3, 4,
137 5, 6,
138 7, 8,
139 9, 10,
140 11, 12,
141 1, 2,
142 3, 4,
143 5, 6,
144 7, 8,
145 9, 10,
146 11, 12
147 };
148
149 std::vector<T> expectedOutput{
150 1, 2,
151 3, 4,
152 5, 6,
153 7, 8,
154 9, 10,
155 11, 12,
156 1, 2,
157 3, 4,
158 5, 6,
159 7, 8,
160 9, 10,
161 11, 12,
162 1, 2,
163 3, 4,
164 5, 6,
165 7, 8,
166 9, 10,
167 11, 12,
168 1, 2,
169 3, 4,
170 5, 6,
171 7, 8,
172 9, 10,
173 11, 12
174 };
175
176 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
177 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
178
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000179 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000180}
181
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000182template<armnn::DataType ArmnnType>
narpra01b9546cf2018-11-20 15:21:28 +0000183void MergerDim2EndToEnd(const std::vector<BackendId>& backends)
184{
185 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000186 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +0000187
188 unsigned int concatAxis = 2;
189 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
190 const TensorShape& outputShape = { 2, 3, 4, 2 };
191
192 // Builds up the structure of the network
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000193 INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000194
195 BOOST_TEST_CHECKPOINT("create a network");
196
197 // Creates structures for input & output.
198 std::vector<T> inputData{
199 1, 2,
200 3, 4,
201 5, 6,
202 7, 8,
203 9, 10,
204 11, 12,
205 1, 2,
206 3, 4,
207 5, 6,
208 7, 8,
209 9, 10,
210 11, 12
211 };
212
213 std::vector<T> expectedOutput{
214 1, 2,
215 3, 4,
216 1, 2,
217 3, 4,
218 5, 6,
219 7, 8,
220 5, 6,
221 7, 8,
222 9, 10,
223 11, 12,
224 9, 10,
225 11, 12,
226 1, 2,
227 3, 4,
228 1, 2,
229 3, 4,
230 5, 6,
231 7, 8,
232 5, 6,
233 7, 8,
234 9, 10,
235 11, 12,
236 9, 10,
237 11, 12
238 };
239
240 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
241 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
242
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000243 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000244}
245
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000246template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
narpra01b9546cf2018-11-20 15:21:28 +0000247void MergerDim3EndToEnd(const std::vector<BackendId>& backends)
248{
249 using namespace armnn;
250
251 unsigned int concatAxis = 3;
252 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
253 const TensorShape& outputShape = { 2, 3, 2, 4 };
254
255 // Builds up the structure of the network
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000256 INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000257
258 BOOST_TEST_CHECKPOINT("create a network");
259
260 // Creates structures for input & output.
261 std::vector<T> inputData{
262 1, 2,
263 3, 4,
264 5, 6,
265 7, 8,
266 9, 10,
267 11, 12,
268 1, 2,
269 3, 4,
270 5, 6,
271 7, 8,
272 9, 10,
273 11, 12
274 };
275
276 std::vector<T> expectedOutput{
277 1, 2,
278 1, 2,
279 3, 4,
280 3, 4,
281 5, 6,
282 5, 6,
283 7, 8,
284 7, 8,
285 9, 10,
286 9, 10,
287 11, 12,
288 11, 12,
289 1, 2,
290 1, 2,
291 3, 4,
292 3, 4,
293 5, 6,
294 5, 6,
295 7, 8,
296 7, 8,
297 9, 10,
298 9, 10,
299 11, 12,
300 11, 12
301 };
302
303 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
304 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
305
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000306 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000307}
308
309} // anonymous namespace