blob: f6de44c95f06049e874ec8ecb69c7ebb4ba90561 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
telsoa01c577f2c2018-08-31 09:22:23 +010012struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010013{
FrancisMurtaghf005e312018-12-06 15:26:04 +000014 explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption)
surmeh01bceff2f2018-03-29 16:29:27 +010015 {
16 m_Prototext = "node {\n"
17 " name: \"Placeholder\"\n"
18 " op: \"Placeholder\"\n"
19 " attr {\n"
20 " key: \"dtype\"\n"
21 " value {\n"
22 " type: DT_FLOAT\n"
23 " }\n"
24 " }\n"
25 " attr {\n"
26 " key: \"value\"\n"
27 " value {\n"
28 " tensor {\n"
29 " dtype: DT_FLOAT\n"
30 " tensor_shape {\n"
31 " }\n"
32 " }\n"
33 " }\n"
34 " }\n"
35 " }\n"
36 "node {\n"
37 " name: \"";
38 m_Prototext.append(poolingtype);
39 m_Prototext.append("\"\n"
40 " op: \"");
41 m_Prototext.append(poolingtype);
42 m_Prototext.append("\"\n"
43 " input: \"Placeholder\"\n"
44 " attr {\n"
45 " key: \"T\"\n"
46 " value {\n"
47 " type: DT_FLOAT\n"
48 " }\n"
49 " }\n"
50 " attr {\n"
51 " key: \"data_format\"\n"
52 " value {\n"
FrancisMurtaghf005e312018-12-06 15:26:04 +000053 " s: \"");
54 m_Prototext.append(dataLayout);
55 m_Prototext.append("\"\n"
surmeh01bceff2f2018-03-29 16:29:27 +010056 " }\n"
57 " }\n"
58 " attr {\n"
59 " key: \"ksize\"\n"
60 " value {\n"
61 " list {\n"
FrancisMurtaghf005e312018-12-06 15:26:04 +000062
63 " i: 1\n");
64 if(dataLayout == "NHWC")
65 {
66 m_Prototext.append(" i: 2\n"
surmeh01bceff2f2018-03-29 16:29:27 +010067 " i: 2\n"
FrancisMurtaghf005e312018-12-06 15:26:04 +000068 " i: 1\n");
69 }
70 else
71 {
72 m_Prototext.append(" i: 1\n"
surmeh01bceff2f2018-03-29 16:29:27 +010073 " i: 2\n"
FrancisMurtaghf005e312018-12-06 15:26:04 +000074 " i: 2\n");
75 }
76 m_Prototext.append(
surmeh01bceff2f2018-03-29 16:29:27 +010077 " }\n"
78 " }\n"
79 " }\n"
80 " attr {\n"
81 " key: \"padding\"\n"
82 " value {\n"
FrancisMurtaghf005e312018-12-06 15:26:04 +000083 " s: \"");
84 m_Prototext.append(paddingOption);
85 m_Prototext.append(
86 "\"\n"
surmeh01bceff2f2018-03-29 16:29:27 +010087 " }\n"
88 " }\n"
89 " attr {\n"
90 " key: \"strides\"\n"
91 " value {\n"
92 " list {\n"
93 " i: 1\n"
94 " i: 1\n"
95 " i: 1\n"
96 " i: 1\n"
97 " }\n"
98 " }\n"
99 " }\n"
100 "}\n");
101
FrancisMurtaghf005e312018-12-06 15:26:04 +0000102 if(dataLayout == "NHWC")
103 {
104 SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype);
105 }
106 else
107 {
108 SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "Placeholder", poolingtype);
109 }
surmeh01bceff2f2018-03-29 16:29:27 +0100110 }
111};
112
113
FrancisMurtaghf005e312018-12-06 15:26:04 +0000114struct MaxPoolFixtureNhwcValid : Pooling2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100115{
FrancisMurtaghf005e312018-12-06 15:26:04 +0000116 MaxPoolFixtureNhwcValid() : Pooling2dFixture("MaxPool", "NHWC", "VALID") {}
surmeh01bceff2f2018-03-29 16:29:27 +0100117};
FrancisMurtaghf005e312018-12-06 15:26:04 +0000118BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcValid, MaxPoolFixtureNhwcValid)
surmeh01bceff2f2018-03-29 16:29:27 +0100119{
120 RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
121}
122
FrancisMurtaghf005e312018-12-06 15:26:04 +0000123struct MaxPoolFixtureNchwValid : Pooling2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100124{
FrancisMurtaghf005e312018-12-06 15:26:04 +0000125 MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {}
surmeh01bceff2f2018-03-29 16:29:27 +0100126};
FrancisMurtaghf005e312018-12-06 15:26:04 +0000127BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid)
128{
129 RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
130}
131
132struct MaxPoolFixtureNhwcSame : Pooling2dFixture
133{
134 MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {}
135};
136BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame)
137{
138 RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f});
139}
140
141struct MaxPoolFixtureNchwSame : Pooling2dFixture
142{
143 MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {}
144};
145BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame)
146{
147 RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f});
148}
149
150struct AvgPoolFixtureNhwcValid : Pooling2dFixture
151{
152 AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {}
153};
154BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid)
surmeh01bceff2f2018-03-29 16:29:27 +0100155{
156 RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
157}
158
FrancisMurtaghf005e312018-12-06 15:26:04 +0000159struct AvgPoolFixtureNchwValid : Pooling2dFixture
160{
161 AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {}
162};
163BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid)
164{
165 RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
166}
167
168struct AvgPoolFixtureNhwcSame : Pooling2dFixture
169{
170 AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {}
171};
172BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame)
173{
174 RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f});
175}
176
177struct AvgPoolFixtureNchwSame : Pooling2dFixture
178{
179 AvgPoolFixtureNchwSame() : Pooling2dFixture("AvgPool", "NCHW", "SAME") {}
180};
181BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame)
182{
183 RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f});
184}
surmeh01bceff2f2018-03-29 16:29:27 +0100185
186BOOST_AUTO_TEST_SUITE_END()