blob: 8fd9021ebc0a1c2989541046ae4dc49e2e61a3bd [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include "OnnxParserTestUtils.hpp"
9
10TEST_SUITE("OnnxParser_Gather")
11{
12
13struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14{
15 GatherMainFixture(const std::vector<int>& indicesShape,
16 const std::vector<int>& indices,
17 const std::vector<int>& inputShape,
18 const std::vector<int>& outputShape)
19 {
20 m_Prototext = R"(
21 ir_version: 8
22 producer_name: "onnx-example"
23 graph {
24 node {
25 output: "indices"
26 op_type: "Constant"
27 attribute {
28 name: "value"
29 t {
30 data_type: 7
31 )" + ConstructIndicesString(indicesShape, indices) + R"(
32 name: "value"
33 }
34 type: TENSOR
35 }
36 }
37 node {
38 input: "input"
39 input: "indices"
40 output: "output"
41 op_type: "Gather"
42 attribute {
43 name: "axis"
44 i: 0
45 type: INT
46 }
47 }
48 name: "gather-model"
49 input {
50 name: "input"
51 type {
52 tensor_type {
53 elem_type: 1
54 shape {
55 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
56 }
57 }
58 }
59 }
60 output {
61 name: "output"
62 type {
63 tensor_type {
64 elem_type: 1
65 shape {
66 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
67 }
68 }
69 }
70 }
71 })";
72 }
73 std::string ConstructIndicesString(const std::vector<int>& indicesShape, const std::vector<int>& indices)
74 {
75 std::string shapeStr;
76 for (int i : indicesShape)
77 {
78 shapeStr = fmt::format(" {} dims: {}", shapeStr, i);
79 }
80 for (int i : indices)
81 {
82 shapeStr = fmt::format(" {} int64_data: {}", shapeStr, i);
83 }
84 return shapeStr;
85 }
86};
87
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010088struct GatherScalarFixture : GatherMainFixture
89{
90 GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { })
91 {
92 Setup();
93 }
94};
95
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +010096struct Gather1dFixture : GatherMainFixture
97{
98 Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 })
99 {
100 Setup();
101 }
102};
103
104struct Gather2dFixture : GatherMainFixture
105{
106 Gather2dFixture() : GatherMainFixture({ 3 }, { 1, 3, 4 }, { 5, 2 }, { 3, 2 })
107 {
108 Setup();
109 }
110};
111
112struct Gather3dMultiIndicesFixture : GatherMainFixture
113{
114 Gather3dMultiIndicesFixture() : GatherMainFixture({ 2, 3 }, { 1, 2, 1, 2, 1, 0 }, { 3, 2, 3 }, { 2, 3, 2, 3 })
115 {
116 Setup();
117 }
118};
119
120struct Gather4dFixture : GatherMainFixture
121{
122 Gather4dFixture() : GatherMainFixture({ 3 }, { 0, 1, 3 }, { 5, 4, 3, 2 }, { 3, 4, 3, 2 })
123 {
124 Setup();
125 }
126};
127
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100128TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest")
129{
130 RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
131 {{"output", { 1.0f }}});
132}
133
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100134TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest")
135{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100136 RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
137 {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}});
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100138}
139
140TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest")
141{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100142 RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
143 {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100144}
145
146TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest")
147{
148 RunTest<3, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
149 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
150 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
151 {{"output", { 7.0f, 8.0f, 9.0f,
152 10.0f, 11.0f, 12.0f,
153 13.0f, 14.0f, 15.0f,
154 16.0f, 17.0f, 18.0f,
155 7.0f, 8.0f, 9.0f,
156 10.0f, 11.0f, 12.0f,
157 13.0f, 14.0f, 15.0f,
158 16.0f, 17.0f, 18.0f,
159 7.0f, 8.0f, 9.0f,
160 10.0f, 11.0f, 12.0f,
161 1.0f, 2.0f, 3.0f,
162 4.0f, 5.0f, 6.0f }}});
163}
164
165TEST_CASE_FIXTURE(Gather4dFixture, "Gather4dTest")
166{
167 RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
168 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
169 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
170 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
171 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
172 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
173 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
174 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
175 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
176 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
177 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
178 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
179 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
180 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
181 71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
182 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
183 81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
184 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
185 91.0f, 92.0f, 93.0f, 94.0f, 95.0f,
186 96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
187 101.0f, 102.0f, 103.0f, 104.0f, 105.0f,
188 106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
189 111.0f, 112.0f, 113.0f, 114.0f, 115.0f,
190 116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}},
191 {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
192 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
193 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
194 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
195 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
196 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,
197 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f,
198 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
199 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f,
200 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f,
201 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
202 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}});
203}
204
205struct GatherRawDataFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
206{
207 GatherRawDataFixture()
208 {
209 m_Prototext = R"(
210 ir_version: 8
211 producer_name: "onnx-example"
212 graph {
213 node {
214 output: "indices"
215 op_type: "Constant"
216 attribute {
217 name: "value"
218 t {
219 dims: 3
220 data_type: 7
221 raw_data:
222 "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
223 name: "value"
224 }
225 type: TENSOR
226 }
227 }
228 node {
229 input: "input"
230 input: "indices"
231 output: "output"
232 op_type: "Gather"
233 attribute {
234 name: "axis"
235 i: 0
236 type: INT
237 }
238 }
239 name: "gather-model"
240 input {
241 name: "input"
242 type {
243 tensor_type {
244 elem_type: 1
245 shape {
246 dim {
247 dim_value: 5
248 }
249 dim {
250 dim_value: 4
251 }
252 dim {
253 dim_value: 3
254 }
255 dim {
256 dim_value: 2
257 }
258 }
259 }
260 }
261 }
262 output {
263 name: "output"
264 type {
265 tensor_type {
266 elem_type: 1
267 shape {
268 dim {
269 dim_value: 3
270 }
271 dim {
272 dim_value: 4
273 }
274 dim {
275 dim_value: 3
276 }
277 dim {
278 dim_value: 2
279 }
280 }
281 }
282 }
283 }
284 })";
285 Setup();
286 }
287};
288
289TEST_CASE_FIXTURE(GatherRawDataFixture, "GatherRawDataTest")
290{
291 RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
292 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
293 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
294 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
295 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
296 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
297 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
298 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
299 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
300 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
301 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
302 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
303 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
304 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
305 71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
306 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
307 81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
308 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
309 91.0f, 92.0f, 93.0f, 94.0f, 95.0f,
310 96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
311 101.0f, 102.0f, 103.0f, 104.0f, 105.0f,
312 106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
313 111.0f, 112.0f, 113.0f, 114.0f, 115.0f,
314 116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}},
315 {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
316 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
317 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
318 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
319 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
320 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,
321 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f,
322 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
323 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f,
324 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f,
325 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
326 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}});
327}
328
329}