blob: ded38572825638b7c0a00cb5cd485385a4fe26e4 [file] [log] [blame]
narpra01b9546cf2018-11-20 15:21:28 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Matteo Martincighf02e6cd2019-05-17 12:15:30 +01007#include "CommonTestUtils.hpp"
8
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01009#include <ResolveType.hpp>
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000010
narpra01b9546cf2018-11-20 15:21:28 +000011#include <armnn/INetwork.hpp>
12
narpra01b9546cf2018-11-20 15:21:28 +000013#include <boost/test/unit_test.hpp>
14
15#include <vector>
16
17namespace
18{
19
20template<typename armnn::DataType DataType>
Jim Flynne242f2d2019-05-22 14:24:13 +010021INetworkPtr CreateConcatNetwork(const std::vector<TensorShape>& inputShapes,
22 const TensorShape &outputShape,
narpra01b9546cf2018-11-20 15:21:28 +000023 unsigned int concatAxis,
24 const float qScale = 1.0f,
25 const int32_t qOffset = 0)
26{
27 using namespace armnn;
28 // Builds up the structure of the network.
29 INetworkPtr net(INetwork::Create());
30
31 OriginsDescriptor descriptor;
32
Jim Flynn825af452019-05-20 12:49:28 +010033 descriptor = CreateDescriptorForConcatenation(inputShapes.begin(),
34 inputShapes.end(),
35 concatAxis);
Jim Flynne242f2d2019-05-22 14:24:13 +010036 IConnectableLayer* concat = net->AddConcatLayer(descriptor, "concat");
narpra01b9546cf2018-11-20 15:21:28 +000037
38 for (unsigned int i = 0; i < inputShapes.size(); ++i)
39 {
40 TensorInfo inputTensorInfo(inputShapes[i], DataType, qScale, qOffset);
41 IConnectableLayer* input = net->AddInputLayer(boost::numeric_cast<LayerBindingId>(i));
Jim Flynne242f2d2019-05-22 14:24:13 +010042 Connect(input, concat, inputTensorInfo, 0, i);
narpra01b9546cf2018-11-20 15:21:28 +000043 }
44
45 TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
46 IConnectableLayer* output = net->AddOutputLayer(0, "output");
Jim Flynne242f2d2019-05-22 14:24:13 +010047 Connect(concat, output, outputTensorInfo, 0, 0);
narpra01b9546cf2018-11-20 15:21:28 +000048
49 return net;
50}
51
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000052template<armnn::DataType ArmnnType>
Jim Flynne242f2d2019-05-22 14:24:13 +010053void ConcatDim0EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +000054{
55 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000056 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +000057
58 unsigned int concatAxis = 0;
59 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
60 const TensorShape& outputShape = { 4, 3, 2, 2 };
61
62 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +010063 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +000064
65 BOOST_TEST_CHECKPOINT("create a network");
66
67 // Creates structures for input & output.
68 std::vector<T> inputData{
69 1, 2,
70 3, 4,
71 5, 6,
72 7, 8,
73 9, 10,
74 11, 12,
75 1, 2,
76 3, 4,
77 5, 6,
78 7, 8,
79 9, 10,
80 11, 12
81 };
82
83 std::vector<T> expectedOutput{
84 1, 2,
85 3, 4,
86 5, 6,
87 7, 8,
88 9, 10,
89 11, 12,
90 1, 2,
91 3, 4,
92 5, 6,
93 7, 8,
94 9, 10,
95 11, 12,
96 1, 2,
97 3, 4,
98 5, 6,
99 7, 8,
100 9, 10,
101 11, 12,
102 1, 2,
103 3, 4,
104 5, 6,
105 7, 8,
106 9, 10,
107 11, 12
108 };
109
110 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
111 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
112
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000113 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000114}
115
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000116template<armnn::DataType ArmnnType>
Jim Flynne242f2d2019-05-22 14:24:13 +0100117void ConcatDim1EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +0000118{
119 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000120 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +0000121
122 unsigned int concatAxis = 1;
123 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
124 const TensorShape& outputShape = { 2, 6, 2, 2 };
125
126 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +0100127 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000128
129 BOOST_TEST_CHECKPOINT("create a network");
130
131 // Creates structures for input & output.
132 std::vector<T> inputData{
133 1, 2,
134 3, 4,
135 5, 6,
136 7, 8,
137 9, 10,
138 11, 12,
139 1, 2,
140 3, 4,
141 5, 6,
142 7, 8,
143 9, 10,
144 11, 12
145 };
146
147 std::vector<T> expectedOutput{
148 1, 2,
149 3, 4,
150 5, 6,
151 7, 8,
152 9, 10,
153 11, 12,
154 1, 2,
155 3, 4,
156 5, 6,
157 7, 8,
158 9, 10,
159 11, 12,
160 1, 2,
161 3, 4,
162 5, 6,
163 7, 8,
164 9, 10,
165 11, 12,
166 1, 2,
167 3, 4,
168 5, 6,
169 7, 8,
170 9, 10,
171 11, 12
172 };
173
174 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
175 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
176
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000177 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000178}
179
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000180template<armnn::DataType ArmnnType>
Jim Flynne242f2d2019-05-22 14:24:13 +0100181void ConcatDim2EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +0000182{
183 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000184 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +0000185
186 unsigned int concatAxis = 2;
187 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
188 const TensorShape& outputShape = { 2, 3, 4, 2 };
189
190 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +0100191 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000192
193 BOOST_TEST_CHECKPOINT("create a network");
194
195 // Creates structures for input & output.
196 std::vector<T> inputData{
197 1, 2,
198 3, 4,
199 5, 6,
200 7, 8,
201 9, 10,
202 11, 12,
203 1, 2,
204 3, 4,
205 5, 6,
206 7, 8,
207 9, 10,
208 11, 12
209 };
210
211 std::vector<T> expectedOutput{
212 1, 2,
213 3, 4,
214 1, 2,
215 3, 4,
216 5, 6,
217 7, 8,
218 5, 6,
219 7, 8,
220 9, 10,
221 11, 12,
222 9, 10,
223 11, 12,
224 1, 2,
225 3, 4,
226 1, 2,
227 3, 4,
228 5, 6,
229 7, 8,
230 5, 6,
231 7, 8,
232 9, 10,
233 11, 12,
234 9, 10,
235 11, 12
236 };
237
238 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
239 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
240
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000241 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000242}
243
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000244template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Jim Flynne242f2d2019-05-22 14:24:13 +0100245void ConcatDim3EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +0000246{
247 using namespace armnn;
248
249 unsigned int concatAxis = 3;
250 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
251 const TensorShape& outputShape = { 2, 3, 2, 4 };
252
253 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +0100254 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000255
256 BOOST_TEST_CHECKPOINT("create a network");
257
258 // Creates structures for input & output.
259 std::vector<T> inputData{
260 1, 2,
261 3, 4,
262 5, 6,
263 7, 8,
264 9, 10,
265 11, 12,
266 1, 2,
267 3, 4,
268 5, 6,
269 7, 8,
270 9, 10,
271 11, 12
272 };
273
274 std::vector<T> expectedOutput{
275 1, 2,
276 1, 2,
277 3, 4,
278 3, 4,
279 5, 6,
280 5, 6,
281 7, 8,
282 7, 8,
283 9, 10,
284 9, 10,
285 11, 12,
286 11, 12,
287 1, 2,
288 1, 2,
289 3, 4,
290 3, 4,
291 5, 6,
292 5, 6,
293 7, 8,
294 7, 8,
295 9, 10,
296 9, 10,
297 11, 12,
298 11, 12
299 };
300
301 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
302 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
303
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000304 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000305}
306
307} // anonymous namespace