blob: a7d5ea03afae40fdcd5965411e5880ac522a525e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct ConcatFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15 unsigned int concatDim)
16 {
17 m_Prototext = R"(
18 node {
19 name: "graphInput0"
20 op: "Placeholder"
21 attr {
22 key: "dtype"
23 value {
24 type: DT_FLOAT
25 }
26 }
27 attr {
28 key: "shape"
29 value {
30 shape {
31 }
32 }
33 }
34 }
35 node {
36 name: "graphInput1"
37 op: "Placeholder"
38 attr {
39 key: "dtype"
40 value {
41 type: DT_FLOAT
42 }
43 }
44 attr {
45 key: "shape"
46 value {
47 shape {
48 }
49 }
50 }
51 }
52 node {
53 name: "concat/axis"
54 op: "Const"
55 attr {
56 key: "dtype"
57 value {
58 type: DT_INT32
59 }
60 }
61 attr {
62 key: "value"
63 value {
64 tensor {
65 dtype: DT_INT32
66 tensor_shape {
67 }
68 int_val: )";
69
70 m_Prototext += std::to_string(concatDim);
71
72 m_Prototext += R"(
73 }
74 }
75 }
76 }
77 node {
78 name: "concat"
79 op: "ConcatV2"
80 input: "graphInput0"
81 input: "graphInput1"
82 input: "concat/axis"
83 attr {
84 key: "N"
85 value {
86 i: 2
87 }
88 }
89 attr {
90 key: "T"
91 value {
92 type: DT_FLOAT
93 }
94 }
95 attr {
96 key: "Tidx"
97 value {
98 type: DT_FLOAT
99 }
100 }
101 }
102 )";
103
104 Setup({{"graphInput0", inputShape0 },
105 {"graphInput1", inputShape1 }}, {"concat"});
106 }
107};
108
109struct ConcatFixtureNCHW : ConcatFixture
110{
111 ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {}
112};
113
114struct ConcatFixtureNHWC : ConcatFixture
115{
116 ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {}
117};
118
119BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
120{
121 RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
122 {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
123 {{"concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }}});
124}
125
126BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC)
127{
128 RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
129 {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
130 {{"concat", { 0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0 }}});
131}
132
133struct ConcatFixtureDim1 : ConcatFixture
134{
135 ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {}
136};
137
138struct ConcatFixtureDim3 : ConcatFixture
139{
140 ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {}
141};
142
143BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1)
144{
145 RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
146 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0 } },
147 { "graphInput1", { 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
148 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } },
149 { { "concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
150 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
151 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
152 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } });
153}
154
155BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3)
156{
157 RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0,
158 4.0, 5.0, 6.0, 7.0,
159 8.0, 9.0, 10.0, 11.0,
160 12.0, 13.0, 14.0, 15.0,
161 16.0, 17.0, 18.0, 19.0,
162 20.0, 21.0, 22.0, 23.0 } },
163 { "graphInput1", { 50.0, 51.0, 52.0, 53.0,
164 54.0, 55.0, 56.0, 57.0,
165 58.0, 59.0, 60.0, 61.0,
166 62.0, 63.0, 64.0, 65.0,
167 66.0, 67.0, 68.0, 69.0,
168 70.0, 71.0, 72.0, 73.0 } } },
169 { { "concat", { 0.0, 1.0, 2.0, 3.0,
170 50.0, 51.0, 52.0, 53.0,
171 4.0, 5.0, 6.0, 7.0,
172 54.0, 55.0, 56.0, 57.0,
173 8.0, 9.0, 10.0, 11.0,
174 58.0, 59.0, 60.0, 61.0,
175 12.0, 13.0, 14.0, 15.0,
176 62.0, 63.0, 64.0, 65.0,
177 16.0, 17.0, 18.0, 19.0,
178 66.0, 67.0, 68.0, 69.0,
179 20.0, 21.0, 22.0, 23.0,
180 70.0, 71.0, 72.0, 73.0 } } });
181}
182
183BOOST_AUTO_TEST_SUITE_END()