//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersSerializeFixture.hpp"
8#include "../Deserializer.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(Deserializer)
14
15struct GatherFixture : public ParserFlatbuffersSerializeFixture
16{
17 explicit GatherFixture(const std::string &inputShape,
18 const std::string &indicesShape,
19 const std::string &input1Content,
20 const std::string &outputShape,
21 const std::string dataType,
22 const std::string constDataType)
23 {
24 m_JsonString = R"(
25 {
26 inputIds: [0],
27 outputIds: [3],
28 layers: [
29 {
30 layer_type: "InputLayer",
31 layer: {
32 base: {
33 layerBindingId: 0,
34 base: {
35 index: 0,
36 layerName: "InputLayer",
37 layerType: "Input",
38 inputSlots: [{
39 index: 0,
40 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
41 }],
42 outputSlots: [ {
43 index: 0,
44 tensorInfo: {
45 dimensions: )" + inputShape + R"(,
46 dataType: )" + dataType + R"(
47 }}]
48 }
49 }}},
50 {
51 layer_type: "ConstantLayer",
52 layer: {
53 base: {
54 index:1,
55 layerName: "ConstantLayer",
56 layerType: "Constant",
57 outputSlots: [ {
58 index: 0,
59 tensorInfo: {
60 dimensions: )" + indicesShape + R"(,
61 dataType: "Signed32",
62 },
63 }],
64 },
65 input: {
66 info: {
67 dimensions: )" + indicesShape + R"(,
68 dataType: )" + dataType + R"(
69 },
70 data_type: )" + constDataType + R"(,
71 data: {
72 data: )" + input1Content + R"(,
73 } }
74 },},
75 {
76 layer_type: "GatherLayer",
77 layer: {
78 base: {
79 index: 2,
80 layerName: "GatherLayer",
81 layerType: "Gather",
82 inputSlots: [
83 {
84 index: 0,
85 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
86 },
87 {
88 index: 1,
89 connection: {sourceLayerIndex:1, outputSlotIndex:0 }
90 }],
91 outputSlots: [ {
92 index: 0,
93 tensorInfo: {
94 dimensions: )" + outputShape + R"(,
95 dataType: )" + dataType + R"(
96
97 }}]}
98 }},
99 {
100 layer_type: "OutputLayer",
101 layer: {
102 base:{
103 layerBindingId: 0,
104 base: {
105 index: 3,
106 layerName: "OutputLayer",
107 layerType: "Output",
108 inputSlots: [{
109 index: 0,
110 connection: {sourceLayerIndex:2, outputSlotIndex:0 },
111 }],
112 outputSlots: [ {
113 index: 0,
114 tensorInfo: {
115 dimensions: )" + outputShape + R"(,
116 dataType: )" + dataType + R"(
117 },
118 }],
119 }}},
120 }]
121 } )";
122
123 Setup();
124 }
125};
126
127struct SimpleGatherFixtureFloat32 : GatherFixture
128{
129 SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
130 "[ 2, 3, 2, 3 ]", "Float32", "IntData") {}
131};
132
133BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
134{
135 RunTest<4, armnn::DataType::Float32>(0,
136 {{"InputLayer", { 1, 2, 3,
137 4, 5, 6,
138 7, 8, 9,
139 10, 11, 12,
140 13, 14, 15,
141 16, 17, 18 }}},
142 {{"OutputLayer", { 7, 8, 9,
143 10, 11, 12,
144 13, 14, 15,
145 16, 17, 18,
146 7, 8, 9,
147 10, 11, 12,
148 13, 14, 15,
149 16, 17, 18,
150 7, 8, 9,
151 10, 11, 12,
152 1, 2, 3,
153 4, 5, 6 }}});
154}
155
156BOOST_AUTO_TEST_SUITE_END()
157