blob: 47919c4481c6823f41129a2f34a62496d6500b95 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +00006#include "ParserFlatbuffersSerializeFixture.hpp"
Finn Williams85d36712021-01-26 22:30:06 +00007#include <armnnDeserializer/IDeserializer.hpp>
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +00008
9#include <string>
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +000010
Sadik Armagan1625efc2021-06-10 18:24:34 +010011TEST_SUITE("Deserializer_Gather")
12{
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +000013struct GatherFixture : public ParserFlatbuffersSerializeFixture
14{
Teresa Charlin52664732020-06-29 16:27:03 +010015 explicit GatherFixture(const std::string& inputShape,
16 const std::string& indicesShape,
17 const std::string& input1Content,
18 const std::string& outputShape,
19 const std::string& axis,
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +000020 const std::string dataType,
21 const std::string constDataType)
22 {
23 m_JsonString = R"(
24 {
25 inputIds: [0],
26 outputIds: [3],
27 layers: [
28 {
29 layer_type: "InputLayer",
30 layer: {
31 base: {
32 layerBindingId: 0,
33 base: {
34 index: 0,
35 layerName: "InputLayer",
36 layerType: "Input",
37 inputSlots: [{
38 index: 0,
39 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
40 }],
41 outputSlots: [ {
42 index: 0,
43 tensorInfo: {
44 dimensions: )" + inputShape + R"(,
45 dataType: )" + dataType + R"(
46 }}]
47 }
48 }}},
49 {
50 layer_type: "ConstantLayer",
51 layer: {
52 base: {
53 index:1,
54 layerName: "ConstantLayer",
55 layerType: "Constant",
56 outputSlots: [ {
57 index: 0,
58 tensorInfo: {
59 dimensions: )" + indicesShape + R"(,
60 dataType: "Signed32",
61 },
62 }],
63 },
64 input: {
65 info: {
66 dimensions: )" + indicesShape + R"(,
67 dataType: )" + dataType + R"(
68 },
69 data_type: )" + constDataType + R"(,
70 data: {
71 data: )" + input1Content + R"(,
72 } }
73 },},
74 {
75 layer_type: "GatherLayer",
76 layer: {
77 base: {
78 index: 2,
79 layerName: "GatherLayer",
80 layerType: "Gather",
81 inputSlots: [
82 {
83 index: 0,
84 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
85 },
86 {
87 index: 1,
88 connection: {sourceLayerIndex:1, outputSlotIndex:0 }
89 }],
90 outputSlots: [ {
91 index: 0,
92 tensorInfo: {
93 dimensions: )" + outputShape + R"(,
94 dataType: )" + dataType + R"(
95
Teresa Charlin52664732020-06-29 16:27:03 +010096 }}]},
97 descriptor: {
98 axis: )" + axis + R"(
99 }
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +0000100 }},
101 {
102 layer_type: "OutputLayer",
103 layer: {
104 base:{
105 layerBindingId: 0,
106 base: {
107 index: 3,
108 layerName: "OutputLayer",
109 layerType: "Output",
110 inputSlots: [{
111 index: 0,
112 connection: {sourceLayerIndex:2, outputSlotIndex:0 },
113 }],
114 outputSlots: [ {
115 index: 0,
116 tensorInfo: {
117 dimensions: )" + outputShape + R"(,
118 dataType: )" + dataType + R"(
119 },
120 }],
121 }}},
122 }]
123 } )";
124
125 Setup();
126 }
127};
128
129struct SimpleGatherFixtureFloat32 : GatherFixture
130{
131 SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
Teresa Charlin52664732020-06-29 16:27:03 +0100132 "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +0000133};
134
Sadik Armagan1625efc2021-06-10 18:24:34 +0100135TEST_CASE_FIXTURE(SimpleGatherFixtureFloat32, "GatherFloat32")
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +0000136{
137 RunTest<4, armnn::DataType::Float32>(0,
138 {{"InputLayer", { 1, 2, 3,
139 4, 5, 6,
140 7, 8, 9,
141 10, 11, 12,
142 13, 14, 15,
143 16, 17, 18 }}},
144 {{"OutputLayer", { 7, 8, 9,
145 10, 11, 12,
146 13, 14, 15,
147 16, 17, 18,
148 7, 8, 9,
149 10, 11, 12,
150 13, 14, 15,
151 16, 17, 18,
152 7, 8, 9,
153 10, 11, 12,
154 1, 2, 3,
155 4, 5, 6 }}});
156}
157
Sadik Armagan1625efc2021-06-10 18:24:34 +0100158}
Saoirse Stewarta1ed73a2019-03-04 13:40:12 +0000159