blob: 69f018f194fc4521ed89de11496dcd001a124513 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
telsoa01c577f2c2018-08-31 09:22:23 +010012struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010013{
14 FusedBatchNormFixture()
15 {
16 m_Prototext = "node { \n"
17 " name: \"graphInput\" \n"
18 " op: \"Placeholder\" \n"
19 " attr { \n"
20 " key: \"dtype\" \n"
21 " value { \n"
22 " type: DT_FLOAT \n"
23 " } \n"
24 " } \n"
25 " attr { \n"
26 " key: \"shape\" \n"
27 " value { \n"
28 " shape { \n"
29 " } \n"
30 " } \n"
31 " } \n"
32 "} \n"
33 "node { \n"
34 " name: \"Const_1\" \n"
35 " op: \"Const\" \n"
36 " attr { \n"
37 " key: \"dtype\" \n"
38 " value { \n"
39 " type: DT_FLOAT \n"
40 " } \n"
41 " } \n"
42 " attr { \n"
43 " key: \"value\" \n"
44 " value { \n"
45 " tensor { \n"
46 " dtype: DT_FLOAT \n"
47 " tensor_shape { \n"
48 " dim { \n"
49 " size: 1 \n"
50 " } \n"
51 " } \n"
52 " float_val: 1.0 \n"
53 " } \n"
54 " } \n"
55 " } \n"
56 "} \n"
57 "node { \n"
58 " name: \"Const_2\" \n"
59 " op: \"Const\" \n"
60 " attr { \n"
61 " key: \"dtype\" \n"
62 " value { \n"
63 " type: DT_FLOAT \n"
64 " } \n"
65 " } \n"
66 " attr { \n"
67 " key: \"value\" \n"
68 " value { \n"
69 " tensor { \n"
70 " dtype: DT_FLOAT \n"
71 " tensor_shape { \n"
72 " dim { \n"
73 " size: 1 \n"
74 " } \n"
75 " } \n"
76 " float_val: 0.0 \n"
77 " } \n"
78 " } \n"
79 " } \n"
80 "} \n"
81 "node { \n"
82 " name: \"FusedBatchNormLayer/mean\" \n"
83 " op: \"Const\" \n"
84 " attr { \n"
85 " key: \"dtype\" \n"
86 " value { \n"
87 " type: DT_FLOAT \n"
88 " } \n"
89 " } \n"
90 " attr { \n"
91 " key: \"value\" \n"
92 " value { \n"
93 " tensor { \n"
94 " dtype: DT_FLOAT \n"
95 " tensor_shape { \n"
96 " dim { \n"
97 " size: 1 \n"
98 " } \n"
99 " } \n"
100 " float_val: 5.0 \n"
101 " } \n"
102 " } \n"
103 " } \n"
104 "} \n"
105 "node { \n"
106 " name: \"FusedBatchNormLayer/variance\" \n"
107 " op: \"Const\" \n"
108 " attr { \n"
109 " key: \"dtype\" \n"
110 " value { \n"
111 " type: DT_FLOAT \n"
112 " } \n"
113 " } \n"
114 " attr { \n"
115 " key: \"value\" \n"
116 " value { \n"
117 " tensor { \n"
118 " dtype: DT_FLOAT \n"
119 " tensor_shape { \n"
120 " dim { \n"
121 " size: 1 \n"
122 " } \n"
123 " } \n"
124 " float_val: 2.0 \n"
125 " } \n"
126 " } \n"
127 " } \n"
128 "} \n"
129 "node { \n"
130 " name: \"output\" \n"
131 " op: \"FusedBatchNorm\" \n"
132 " input: \"graphInput\" \n"
133 " input: \"Const_1\" \n"
134 " input: \"Const_2\" \n"
135 " input: \"FusedBatchNormLayer/mean\" \n"
136 " input: \"FusedBatchNormLayer/variance\" \n"
137 " attr { \n"
138 " key: \"T\" \n"
139 " value { \n"
140 " type: DT_FLOAT \n"
141 " } \n"
142 " } \n"
143 " attr { \n"
144 " key: \"data_format\" \n"
145 " value { \n"
146 " s: \"NHWC\" \n"
147 " } \n"
148 " } \n"
149 " attr { \n"
150 " key: \"epsilon\" \n"
151 " value { \n"
152 " f: 0.0010000000475 \n"
153 " } \n"
154 " } \n"
155 " attr { \n"
156 " key: \"is_training\" \n"
157 " value { \n"
158 " b: false \n"
159 " } \n"
160 " } \n"
161 "} \n";
162
163 SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output");
164 }
165};
166
167BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture)
168{
telsoa01c577f2c2018-08-31 09:22:23 +0100169 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, // Input data.
surmeh01bceff2f2018-03-29 16:29:27 +0100170 {-2.8277204f, -2.12079024f, -1.4138602f,
171 -0.7069301f, 0.0f, 0.7069301f,
telsoa01c577f2c2018-08-31 09:22:23 +0100172 1.4138602f, 2.12079024f, 2.8277204f}); // Expected output data.
surmeh01bceff2f2018-03-29 16:29:27 +0100173}
174
175BOOST_AUTO_TEST_SUITE_END()