blob: f603b22afd23c5e6dcc545628aee00ed53acb40e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
telsoa01c577f2c2018-08-31 09:22:23 +010012struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010013{
14 explicit Pooling2dFixture(const char* poolingtype)
15 {
16 m_Prototext = "node {\n"
17 " name: \"Placeholder\"\n"
18 " op: \"Placeholder\"\n"
19 " attr {\n"
20 " key: \"dtype\"\n"
21 " value {\n"
22 " type: DT_FLOAT\n"
23 " }\n"
24 " }\n"
25 " attr {\n"
26 " key: \"value\"\n"
27 " value {\n"
28 " tensor {\n"
29 " dtype: DT_FLOAT\n"
30 " tensor_shape {\n"
31 " }\n"
32 " }\n"
33 " }\n"
34 " }\n"
35 " }\n"
36 "node {\n"
37 " name: \"";
38 m_Prototext.append(poolingtype);
39 m_Prototext.append("\"\n"
40 " op: \"");
41 m_Prototext.append(poolingtype);
42 m_Prototext.append("\"\n"
43 " input: \"Placeholder\"\n"
44 " attr {\n"
45 " key: \"T\"\n"
46 " value {\n"
47 " type: DT_FLOAT\n"
48 " }\n"
49 " }\n"
50 " attr {\n"
51 " key: \"data_format\"\n"
52 " value {\n"
53 " s: \"NHWC\"\n"
54 " }\n"
55 " }\n"
56 " attr {\n"
57 " key: \"ksize\"\n"
58 " value {\n"
59 " list {\n"
60 " i: 1\n"
61 " i: 2\n"
62 " i: 2\n"
63 " i: 1\n"
64 " }\n"
65 " }\n"
66 " }\n"
67 " attr {\n"
68 " key: \"padding\"\n"
69 " value {\n"
70 " s: \"VALID\"\n"
71 " }\n"
72 " }\n"
73 " attr {\n"
74 " key: \"strides\"\n"
75 " value {\n"
76 " list {\n"
77 " i: 1\n"
78 " i: 1\n"
79 " i: 1\n"
80 " i: 1\n"
81 " }\n"
82 " }\n"
83 " }\n"
84 "}\n");
85
86 SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype);
87 }
88};
89
90
91struct MaxPoolFixture : Pooling2dFixture
92{
93 MaxPoolFixture() : Pooling2dFixture("MaxPool") {}
94};
95BOOST_FIXTURE_TEST_CASE(ParseMaxPool, MaxPoolFixture)
96{
97 RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
98}
99
100
101struct AvgPoolFixture : Pooling2dFixture
102{
103 AvgPoolFixture() : Pooling2dFixture("AvgPool") {}
104};
105BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture)
106{
107 RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
108}
109
110
111BOOST_AUTO_TEST_SUITE_END()