surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 4 | // |
| 5 | |
| 6 | #include <boost/test/unit_test.hpp> |
| 7 | #include "armnnTfParser/ITfParser.hpp" |
| 8 | #include "ParserPrototxtFixture.hpp" |
| 9 | |
| 10 | BOOST_AUTO_TEST_SUITE(TensorflowParser) |
| 11 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 12 | struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 13 | { |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 14 | explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption) |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 15 | { |
| 16 | m_Prototext = "node {\n" |
| 17 | " name: \"Placeholder\"\n" |
| 18 | " op: \"Placeholder\"\n" |
| 19 | " attr {\n" |
| 20 | " key: \"dtype\"\n" |
| 21 | " value {\n" |
| 22 | " type: DT_FLOAT\n" |
| 23 | " }\n" |
| 24 | " }\n" |
| 25 | " attr {\n" |
| 26 | " key: \"value\"\n" |
| 27 | " value {\n" |
| 28 | " tensor {\n" |
| 29 | " dtype: DT_FLOAT\n" |
| 30 | " tensor_shape {\n" |
| 31 | " }\n" |
| 32 | " }\n" |
| 33 | " }\n" |
| 34 | " }\n" |
| 35 | " }\n" |
| 36 | "node {\n" |
| 37 | " name: \""; |
| 38 | m_Prototext.append(poolingtype); |
| 39 | m_Prototext.append("\"\n" |
| 40 | " op: \""); |
| 41 | m_Prototext.append(poolingtype); |
| 42 | m_Prototext.append("\"\n" |
| 43 | " input: \"Placeholder\"\n" |
| 44 | " attr {\n" |
| 45 | " key: \"T\"\n" |
| 46 | " value {\n" |
| 47 | " type: DT_FLOAT\n" |
| 48 | " }\n" |
| 49 | " }\n" |
| 50 | " attr {\n" |
| 51 | " key: \"data_format\"\n" |
| 52 | " value {\n" |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 53 | " s: \""); |
| 54 | m_Prototext.append(dataLayout); |
| 55 | m_Prototext.append("\"\n" |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 56 | " }\n" |
| 57 | " }\n" |
| 58 | " attr {\n" |
| 59 | " key: \"ksize\"\n" |
| 60 | " value {\n" |
| 61 | " list {\n" |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 62 | |
| 63 | " i: 1\n"); |
| 64 | if(dataLayout == "NHWC") |
| 65 | { |
| 66 | m_Prototext.append(" i: 2\n" |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 67 | " i: 2\n" |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 68 | " i: 1\n"); |
| 69 | } |
| 70 | else |
| 71 | { |
| 72 | m_Prototext.append(" i: 1\n" |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 73 | " i: 2\n" |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 74 | " i: 2\n"); |
| 75 | } |
| 76 | m_Prototext.append( |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 77 | " }\n" |
| 78 | " }\n" |
| 79 | " }\n" |
| 80 | " attr {\n" |
| 81 | " key: \"padding\"\n" |
| 82 | " value {\n" |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 83 | " s: \""); |
| 84 | m_Prototext.append(paddingOption); |
| 85 | m_Prototext.append( |
| 86 | "\"\n" |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 87 | " }\n" |
| 88 | " }\n" |
| 89 | " attr {\n" |
| 90 | " key: \"strides\"\n" |
| 91 | " value {\n" |
| 92 | " list {\n" |
| 93 | " i: 1\n" |
| 94 | " i: 1\n" |
| 95 | " i: 1\n" |
| 96 | " i: 1\n" |
| 97 | " }\n" |
| 98 | " }\n" |
| 99 | " }\n" |
| 100 | "}\n"); |
| 101 | |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 102 | if(dataLayout == "NHWC") |
| 103 | { |
| 104 | SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype); |
| 105 | } |
| 106 | else |
| 107 | { |
| 108 | SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "Placeholder", poolingtype); |
| 109 | } |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 110 | } |
| 111 | }; |
| 112 | |
| 113 | |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 114 | struct MaxPoolFixtureNhwcValid : Pooling2dFixture |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 115 | { |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 116 | MaxPoolFixtureNhwcValid() : Pooling2dFixture("MaxPool", "NHWC", "VALID") {} |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 117 | }; |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 118 | BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcValid, MaxPoolFixtureNhwcValid) |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 119 | { |
| 120 | RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); |
| 121 | } |
| 122 | |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 123 | struct MaxPoolFixtureNchwValid : Pooling2dFixture |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 124 | { |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 125 | MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {} |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 126 | }; |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 127 | BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid) |
| 128 | { |
| 129 | RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); |
| 130 | } |
| 131 | |
| 132 | struct MaxPoolFixtureNhwcSame : Pooling2dFixture |
| 133 | { |
| 134 | MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {} |
| 135 | }; |
| 136 | BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame) |
| 137 | { |
| 138 | RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); |
| 139 | } |
| 140 | |
| 141 | struct MaxPoolFixtureNchwSame : Pooling2dFixture |
| 142 | { |
| 143 | MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {} |
| 144 | }; |
| 145 | BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame) |
| 146 | { |
| 147 | RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); |
| 148 | } |
| 149 | |
| 150 | struct AvgPoolFixtureNhwcValid : Pooling2dFixture |
| 151 | { |
| 152 | AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {} |
| 153 | }; |
| 154 | BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid) |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 155 | { |
| 156 | RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); |
| 157 | } |
| 158 | |
FrancisMurtagh | f005e31 | 2018-12-06 15:26:04 +0000 | [diff] [blame] | 159 | struct AvgPoolFixtureNchwValid : Pooling2dFixture |
| 160 | { |
| 161 | AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {} |
| 162 | }; |
| 163 | BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid) |
| 164 | { |
| 165 | RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); |
| 166 | } |
| 167 | |
| 168 | struct AvgPoolFixtureNhwcSame : Pooling2dFixture |
| 169 | { |
| 170 | AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {} |
| 171 | }; |
| 172 | BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame) |
| 173 | { |
| 174 | RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); |
| 175 | } |
| 176 | |
| 177 | struct AvgPoolFixtureNchwSame : Pooling2dFixture |
| 178 | { |
| 179 | AvgPoolFixtureNchwSame() : Pooling2dFixture("AvgPool", "NCHW", "SAME") {} |
| 180 | }; |
| 181 | BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame) |
| 182 | { |
| 183 | RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); |
| 184 | } |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 185 | |
| 186 | BOOST_AUTO_TEST_SUITE_END() |