blob: 36ffa47def4ab24232fea6d5e3cfede85f910fe9 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12
13struct Pooling2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
14{
15 explicit Pooling2dFixture(const char* poolingtype)
16 {
17 m_Prototext = "node {\n"
18 " name: \"Placeholder\"\n"
19 " op: \"Placeholder\"\n"
20 " attr {\n"
21 " key: \"dtype\"\n"
22 " value {\n"
23 " type: DT_FLOAT\n"
24 " }\n"
25 " }\n"
26 " attr {\n"
27 " key: \"value\"\n"
28 " value {\n"
29 " tensor {\n"
30 " dtype: DT_FLOAT\n"
31 " tensor_shape {\n"
32 " }\n"
33 " }\n"
34 " }\n"
35 " }\n"
36 " }\n"
37 "node {\n"
38 " name: \"";
39 m_Prototext.append(poolingtype);
40 m_Prototext.append("\"\n"
41 " op: \"");
42 m_Prototext.append(poolingtype);
43 m_Prototext.append("\"\n"
44 " input: \"Placeholder\"\n"
45 " attr {\n"
46 " key: \"T\"\n"
47 " value {\n"
48 " type: DT_FLOAT\n"
49 " }\n"
50 " }\n"
51 " attr {\n"
52 " key: \"data_format\"\n"
53 " value {\n"
54 " s: \"NHWC\"\n"
55 " }\n"
56 " }\n"
57 " attr {\n"
58 " key: \"ksize\"\n"
59 " value {\n"
60 " list {\n"
61 " i: 1\n"
62 " i: 2\n"
63 " i: 2\n"
64 " i: 1\n"
65 " }\n"
66 " }\n"
67 " }\n"
68 " attr {\n"
69 " key: \"padding\"\n"
70 " value {\n"
71 " s: \"VALID\"\n"
72 " }\n"
73 " }\n"
74 " attr {\n"
75 " key: \"strides\"\n"
76 " value {\n"
77 " list {\n"
78 " i: 1\n"
79 " i: 1\n"
80 " i: 1\n"
81 " i: 1\n"
82 " }\n"
83 " }\n"
84 " }\n"
85 "}\n");
86
87 SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype);
88 }
89};
90
91
92struct MaxPoolFixture : Pooling2dFixture
93{
94 MaxPoolFixture() : Pooling2dFixture("MaxPool") {}
95};
96BOOST_FIXTURE_TEST_CASE(ParseMaxPool, MaxPoolFixture)
97{
98 RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
99}
100
101
102struct AvgPoolFixture : Pooling2dFixture
103{
104 AvgPoolFixture() : Pooling2dFixture("AvgPool") {}
105};
106BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture)
107{
108 RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
109}
110
111
112BOOST_AUTO_TEST_SUITE_END()