blob: 2614523e8ef39587103c0b1a6e4efe10c35f835f [file] [log] [blame]
narpra01b9546cf2018-11-20 15:21:28 +00001//
Kevin May5b58e312022-12-15 10:15:21 +00002// Copyright © 2017,2022 Arm Ltd. All rights reserved.
narpra01b9546cf2018-11-20 15:21:28 +00003// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Sadik Armagana097d2a2021-11-24 15:47:28 +00007#include <CommonTestUtils.hpp>
Matteo Martincighf02e6cd2019-05-17 12:15:30 +01008
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01009#include <ResolveType.hpp>
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000010
narpra01b9546cf2018-11-20 15:21:28 +000011#include <armnn/INetwork.hpp>
12
Matthew Sloyan171214c2020-09-09 09:07:37 +010013#include <armnn/utility/NumericCast.hpp>
14
Sadik Armagan1625efc2021-06-10 18:24:34 +010015#include <doctest/doctest.h>
narpra01b9546cf2018-11-20 15:21:28 +000016
17#include <vector>
18
19namespace
20{
21
22template<typename armnn::DataType DataType>
Jim Flynne242f2d2019-05-22 14:24:13 +010023INetworkPtr CreateConcatNetwork(const std::vector<TensorShape>& inputShapes,
24 const TensorShape &outputShape,
narpra01b9546cf2018-11-20 15:21:28 +000025 unsigned int concatAxis,
26 const float qScale = 1.0f,
27 const int32_t qOffset = 0)
28{
29 using namespace armnn;
30 // Builds up the structure of the network.
31 INetworkPtr net(INetwork::Create());
32
33 OriginsDescriptor descriptor;
34
Jim Flynn825af452019-05-20 12:49:28 +010035 descriptor = CreateDescriptorForConcatenation(inputShapes.begin(),
36 inputShapes.end(),
37 concatAxis);
Jim Flynne242f2d2019-05-22 14:24:13 +010038 IConnectableLayer* concat = net->AddConcatLayer(descriptor, "concat");
narpra01b9546cf2018-11-20 15:21:28 +000039
40 for (unsigned int i = 0; i < inputShapes.size(); ++i)
41 {
Cathal Corbett5b8093c2021-10-22 11:12:07 +010042 TensorInfo inputTensorInfo(inputShapes[i], DataType, qScale, qOffset, true);
Matthew Sloyan171214c2020-09-09 09:07:37 +010043 IConnectableLayer* input = net->AddInputLayer(armnn::numeric_cast<LayerBindingId>(i));
Jim Flynne242f2d2019-05-22 14:24:13 +010044 Connect(input, concat, inputTensorInfo, 0, i);
narpra01b9546cf2018-11-20 15:21:28 +000045 }
46
47 TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
48 IConnectableLayer* output = net->AddOutputLayer(0, "output");
Jim Flynne242f2d2019-05-22 14:24:13 +010049 Connect(concat, output, outputTensorInfo, 0, 0);
narpra01b9546cf2018-11-20 15:21:28 +000050
51 return net;
52}
53
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000054template<armnn::DataType ArmnnType>
Jim Flynne242f2d2019-05-22 14:24:13 +010055void ConcatDim0EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +000056{
57 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000058 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +000059
60 unsigned int concatAxis = 0;
61 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
62 const TensorShape& outputShape = { 4, 3, 2, 2 };
63
64 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +010065 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +000066
Sadik Armagan1625efc2021-06-10 18:24:34 +010067 CHECK(net);
narpra01b9546cf2018-11-20 15:21:28 +000068
69 // Creates structures for input & output.
70 std::vector<T> inputData{
71 1, 2,
72 3, 4,
73 5, 6,
74 7, 8,
75 9, 10,
76 11, 12,
77 1, 2,
78 3, 4,
79 5, 6,
80 7, 8,
81 9, 10,
82 11, 12
83 };
84
85 std::vector<T> expectedOutput{
86 1, 2,
87 3, 4,
88 5, 6,
89 7, 8,
90 9, 10,
91 11, 12,
92 1, 2,
93 3, 4,
94 5, 6,
95 7, 8,
96 9, 10,
97 11, 12,
98 1, 2,
99 3, 4,
100 5, 6,
101 7, 8,
102 9, 10,
103 11, 12,
104 1, 2,
105 3, 4,
106 5, 6,
107 7, 8,
108 9, 10,
109 11, 12
110 };
111
112 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
113 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
114
Kevin May5b58e312022-12-15 10:15:21 +0000115 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000116}
117
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000118template<armnn::DataType ArmnnType>
Jim Flynne242f2d2019-05-22 14:24:13 +0100119void ConcatDim1EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +0000120{
121 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000122 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +0000123
124 unsigned int concatAxis = 1;
125 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
126 const TensorShape& outputShape = { 2, 6, 2, 2 };
127
128 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +0100129 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000130
narpra01b9546cf2018-11-20 15:21:28 +0000131 // Creates structures for input & output.
132 std::vector<T> inputData{
133 1, 2,
134 3, 4,
135 5, 6,
136 7, 8,
137 9, 10,
138 11, 12,
139 1, 2,
140 3, 4,
141 5, 6,
142 7, 8,
143 9, 10,
144 11, 12
145 };
146
147 std::vector<T> expectedOutput{
148 1, 2,
149 3, 4,
150 5, 6,
151 7, 8,
152 9, 10,
153 11, 12,
154 1, 2,
155 3, 4,
156 5, 6,
157 7, 8,
158 9, 10,
159 11, 12,
160 1, 2,
161 3, 4,
162 5, 6,
163 7, 8,
164 9, 10,
165 11, 12,
166 1, 2,
167 3, 4,
168 5, 6,
169 7, 8,
170 9, 10,
171 11, 12
172 };
173
174 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
175 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
176
Kevin May5b58e312022-12-15 10:15:21 +0000177 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000178}
179
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000180template<armnn::DataType ArmnnType>
Jim Flynne242f2d2019-05-22 14:24:13 +0100181void ConcatDim2EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +0000182{
183 using namespace armnn;
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000184 using T = ResolveType<ArmnnType>;
narpra01b9546cf2018-11-20 15:21:28 +0000185
186 unsigned int concatAxis = 2;
187 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
188 const TensorShape& outputShape = { 2, 3, 4, 2 };
189
190 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +0100191 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000192
narpra01b9546cf2018-11-20 15:21:28 +0000193 // Creates structures for input & output.
194 std::vector<T> inputData{
195 1, 2,
196 3, 4,
197 5, 6,
198 7, 8,
199 9, 10,
200 11, 12,
201 1, 2,
202 3, 4,
203 5, 6,
204 7, 8,
205 9, 10,
206 11, 12
207 };
208
209 std::vector<T> expectedOutput{
210 1, 2,
211 3, 4,
212 1, 2,
213 3, 4,
214 5, 6,
215 7, 8,
216 5, 6,
217 7, 8,
218 9, 10,
219 11, 12,
220 9, 10,
221 11, 12,
222 1, 2,
223 3, 4,
224 1, 2,
225 3, 4,
226 5, 6,
227 7, 8,
228 5, 6,
229 7, 8,
230 9, 10,
231 11, 12,
232 9, 10,
233 11, 12
234 };
235
236 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
237 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
238
Kevin May5b58e312022-12-15 10:15:21 +0000239 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000240}
241
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000242template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Jim Flynne242f2d2019-05-22 14:24:13 +0100243void ConcatDim3EndToEnd(const std::vector<BackendId>& backends)
narpra01b9546cf2018-11-20 15:21:28 +0000244{
245 using namespace armnn;
246
247 unsigned int concatAxis = 3;
248 const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
249 const TensorShape& outputShape = { 2, 3, 2, 4 };
250
251 // Builds up the structure of the network
Jim Flynne242f2d2019-05-22 14:24:13 +0100252 INetworkPtr net = CreateConcatNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
narpra01b9546cf2018-11-20 15:21:28 +0000253
narpra01b9546cf2018-11-20 15:21:28 +0000254 // Creates structures for input & output.
255 std::vector<T> inputData{
256 1, 2,
257 3, 4,
258 5, 6,
259 7, 8,
260 9, 10,
261 11, 12,
262 1, 2,
263 3, 4,
264 5, 6,
265 7, 8,
266 9, 10,
267 11, 12
268 };
269
270 std::vector<T> expectedOutput{
271 1, 2,
272 1, 2,
273 3, 4,
274 3, 4,
275 5, 6,
276 5, 6,
277 7, 8,
278 7, 8,
279 9, 10,
280 9, 10,
281 11, 12,
282 11, 12,
283 1, 2,
284 1, 2,
285 3, 4,
286 3, 4,
287 5, 6,
288 5, 6,
289 7, 8,
290 7, 8,
291 9, 10,
292 9, 10,
293 11, 12,
294 11, 12
295 };
296
297 std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }, { 1,inputData }};
298 std::map<int, std::vector<T>> expectedOutputData = {{ 0,expectedOutput }};
299
Kevin May5b58e312022-12-15 10:15:21 +0000300 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
narpra01b9546cf2018-11-20 15:21:28 +0000301}
302
303} // anonymous namespace