blob: b93a4728d0f35292f2ac0a31c7234c012976dc72 [file] [log] [blame]
surmeh01bceff2f2018-03-29 16:29:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh01bceff2f2018-03-29 16:29:27 +01004//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
Matteo Martincigh075c7502018-12-05 13:10:45 +000010#include <array>
11
surmeh01bceff2f2018-03-29 16:29:27 +010012BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
telsoa01c577f2c2018-08-31 09:22:23 +010014struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010015{
Matteo Martincigh075c7502018-12-05 13:10:45 +000016 explicit FusedBatchNormFixture(const std::string& dataLayout)
surmeh01bceff2f2018-03-29 16:29:27 +010017 {
18 m_Prototext = "node { \n"
19 " name: \"graphInput\" \n"
20 " op: \"Placeholder\" \n"
21 " attr { \n"
22 " key: \"dtype\" \n"
23 " value { \n"
24 " type: DT_FLOAT \n"
25 " } \n"
26 " } \n"
27 " attr { \n"
28 " key: \"shape\" \n"
29 " value { \n"
30 " shape { \n"
31 " } \n"
32 " } \n"
33 " } \n"
34 "} \n"
35 "node { \n"
36 " name: \"Const_1\" \n"
37 " op: \"Const\" \n"
38 " attr { \n"
39 " key: \"dtype\" \n"
40 " value { \n"
41 " type: DT_FLOAT \n"
42 " } \n"
43 " } \n"
44 " attr { \n"
45 " key: \"value\" \n"
46 " value { \n"
47 " tensor { \n"
48 " dtype: DT_FLOAT \n"
49 " tensor_shape { \n"
50 " dim { \n"
51 " size: 1 \n"
52 " } \n"
53 " } \n"
54 " float_val: 1.0 \n"
55 " } \n"
56 " } \n"
57 " } \n"
58 "} \n"
59 "node { \n"
60 " name: \"Const_2\" \n"
61 " op: \"Const\" \n"
62 " attr { \n"
63 " key: \"dtype\" \n"
64 " value { \n"
65 " type: DT_FLOAT \n"
66 " } \n"
67 " } \n"
68 " attr { \n"
69 " key: \"value\" \n"
70 " value { \n"
71 " tensor { \n"
72 " dtype: DT_FLOAT \n"
73 " tensor_shape { \n"
74 " dim { \n"
75 " size: 1 \n"
76 " } \n"
77 " } \n"
78 " float_val: 0.0 \n"
79 " } \n"
80 " } \n"
81 " } \n"
82 "} \n"
83 "node { \n"
84 " name: \"FusedBatchNormLayer/mean\" \n"
85 " op: \"Const\" \n"
86 " attr { \n"
87 " key: \"dtype\" \n"
88 " value { \n"
89 " type: DT_FLOAT \n"
90 " } \n"
91 " } \n"
92 " attr { \n"
93 " key: \"value\" \n"
94 " value { \n"
95 " tensor { \n"
96 " dtype: DT_FLOAT \n"
97 " tensor_shape { \n"
98 " dim { \n"
99 " size: 1 \n"
100 " } \n"
101 " } \n"
102 " float_val: 5.0 \n"
103 " } \n"
104 " } \n"
105 " } \n"
106 "} \n"
107 "node { \n"
108 " name: \"FusedBatchNormLayer/variance\" \n"
109 " op: \"Const\" \n"
110 " attr { \n"
111 " key: \"dtype\" \n"
112 " value { \n"
113 " type: DT_FLOAT \n"
114 " } \n"
115 " } \n"
116 " attr { \n"
117 " key: \"value\" \n"
118 " value { \n"
119 " tensor { \n"
120 " dtype: DT_FLOAT \n"
121 " tensor_shape { \n"
122 " dim { \n"
123 " size: 1 \n"
124 " } \n"
125 " } \n"
126 " float_val: 2.0 \n"
127 " } \n"
128 " } \n"
129 " } \n"
130 "} \n"
131 "node { \n"
132 " name: \"output\" \n"
133 " op: \"FusedBatchNorm\" \n"
134 " input: \"graphInput\" \n"
135 " input: \"Const_1\" \n"
136 " input: \"Const_2\" \n"
137 " input: \"FusedBatchNormLayer/mean\" \n"
138 " input: \"FusedBatchNormLayer/variance\" \n"
139 " attr { \n"
140 " key: \"T\" \n"
141 " value { \n"
142 " type: DT_FLOAT \n"
143 " } \n"
Aron Virginas-Tar2e259272019-11-27 13:29:51 +0000144 " } \n";
145
146 // NOTE: we only explicitly set data_format when it is not the default NHWC
147 if (dataLayout != "NHWC")
148 {
149 m_Prototext.append(" attr { \n"
150 " key: \"data_format\" \n"
151 " value { \n"
152 " s: \"");
153 m_Prototext.append(dataLayout);
154 m_Prototext.append("\" \n"
155 " } \n"
156 " } \n");
157 }
158
159 m_Prototext.append(" attr { \n"
Matteo Martincigh075c7502018-12-05 13:10:45 +0000160 " key: \"epsilon\" \n"
161 " value { \n"
162 " f: 0.0010000000475 \n"
163 " } \n"
164 " } \n"
165 " attr { \n"
166 " key: \"is_training\" \n"
167 " value { \n"
168 " b: false \n"
169 " } \n"
170 " } \n"
171 "} \n");
surmeh01bceff2f2018-03-29 16:29:27 +0100172
Matteo Martincigh075c7502018-12-05 13:10:45 +0000173 // Set the input shape according to the data layout
174 std::array<unsigned int, 4> dims;
175 if (dataLayout == "NHWC")
176 {
177 dims = { 1u, 3u, 3u, 1u };
178 }
179 else // dataLayout == "NCHW"
180 {
181 dims = { 1u, 1u, 3u, 3u };
182 }
183
184 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
surmeh01bceff2f2018-03-29 16:29:27 +0100185 }
186};
187
Matteo Martincigh075c7502018-12-05 13:10:45 +0000188struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100189{
Matteo Martincigh075c7502018-12-05 13:10:45 +0000190 FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){}
191};
192BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
193{
194 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
195 { -2.8277204f, -2.12079024f, -1.4138602f,
196 -0.7069301f, 0.0f, 0.7069301f,
197 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
198}
199
200struct FusedBatchNormNchwFixture : FusedBatchNormFixture
201{
202 FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){}
203};
204BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
205{
206 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
207 { -2.8277204f, -2.12079024f, -1.4138602f,
208 -0.7069301f, 0.0f, 0.7069301f,
209 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
surmeh01bceff2f2018-03-29 16:29:27 +0100210}
211
212BOOST_AUTO_TEST_SUITE_END()