blob: 98bdb261836c61c9625e381302169475f31ca3ac [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <array>
#include <string>

surmeh01bceff2f2018-03-29 16:29:27 +010012BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
telsoa01c577f2c2018-08-31 09:22:23 +010014struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010015{
Matteo Martincigh075c7502018-12-05 13:10:45 +000016 explicit FusedBatchNormFixture(const std::string& dataLayout)
surmeh01bceff2f2018-03-29 16:29:27 +010017 {
18 m_Prototext = "node { \n"
19 " name: \"graphInput\" \n"
20 " op: \"Placeholder\" \n"
21 " attr { \n"
22 " key: \"dtype\" \n"
23 " value { \n"
24 " type: DT_FLOAT \n"
25 " } \n"
26 " } \n"
27 " attr { \n"
28 " key: \"shape\" \n"
29 " value { \n"
30 " shape { \n"
31 " } \n"
32 " } \n"
33 " } \n"
34 "} \n"
35 "node { \n"
36 " name: \"Const_1\" \n"
37 " op: \"Const\" \n"
38 " attr { \n"
39 " key: \"dtype\" \n"
40 " value { \n"
41 " type: DT_FLOAT \n"
42 " } \n"
43 " } \n"
44 " attr { \n"
45 " key: \"value\" \n"
46 " value { \n"
47 " tensor { \n"
48 " dtype: DT_FLOAT \n"
49 " tensor_shape { \n"
50 " dim { \n"
51 " size: 1 \n"
52 " } \n"
53 " } \n"
54 " float_val: 1.0 \n"
55 " } \n"
56 " } \n"
57 " } \n"
58 "} \n"
59 "node { \n"
60 " name: \"Const_2\" \n"
61 " op: \"Const\" \n"
62 " attr { \n"
63 " key: \"dtype\" \n"
64 " value { \n"
65 " type: DT_FLOAT \n"
66 " } \n"
67 " } \n"
68 " attr { \n"
69 " key: \"value\" \n"
70 " value { \n"
71 " tensor { \n"
72 " dtype: DT_FLOAT \n"
73 " tensor_shape { \n"
74 " dim { \n"
75 " size: 1 \n"
76 " } \n"
77 " } \n"
78 " float_val: 0.0 \n"
79 " } \n"
80 " } \n"
81 " } \n"
82 "} \n"
83 "node { \n"
84 " name: \"FusedBatchNormLayer/mean\" \n"
85 " op: \"Const\" \n"
86 " attr { \n"
87 " key: \"dtype\" \n"
88 " value { \n"
89 " type: DT_FLOAT \n"
90 " } \n"
91 " } \n"
92 " attr { \n"
93 " key: \"value\" \n"
94 " value { \n"
95 " tensor { \n"
96 " dtype: DT_FLOAT \n"
97 " tensor_shape { \n"
98 " dim { \n"
99 " size: 1 \n"
100 " } \n"
101 " } \n"
102 " float_val: 5.0 \n"
103 " } \n"
104 " } \n"
105 " } \n"
106 "} \n"
107 "node { \n"
108 " name: \"FusedBatchNormLayer/variance\" \n"
109 " op: \"Const\" \n"
110 " attr { \n"
111 " key: \"dtype\" \n"
112 " value { \n"
113 " type: DT_FLOAT \n"
114 " } \n"
115 " } \n"
116 " attr { \n"
117 " key: \"value\" \n"
118 " value { \n"
119 " tensor { \n"
120 " dtype: DT_FLOAT \n"
121 " tensor_shape { \n"
122 " dim { \n"
123 " size: 1 \n"
124 " } \n"
125 " } \n"
126 " float_val: 2.0 \n"
127 " } \n"
128 " } \n"
129 " } \n"
130 "} \n"
131 "node { \n"
132 " name: \"output\" \n"
133 " op: \"FusedBatchNorm\" \n"
134 " input: \"graphInput\" \n"
135 " input: \"Const_1\" \n"
136 " input: \"Const_2\" \n"
137 " input: \"FusedBatchNormLayer/mean\" \n"
138 " input: \"FusedBatchNormLayer/variance\" \n"
139 " attr { \n"
140 " key: \"T\" \n"
141 " value { \n"
142 " type: DT_FLOAT \n"
143 " } \n"
144 " } \n"
145 " attr { \n"
146 " key: \"data_format\" \n"
147 " value { \n"
Matteo Martincigh075c7502018-12-05 13:10:45 +0000148 " s: \"";
149 m_Prototext.append(dataLayout);
150 m_Prototext.append("\" \n"
151 " } \n"
152 " } \n"
153 " attr { \n"
154 " key: \"epsilon\" \n"
155 " value { \n"
156 " f: 0.0010000000475 \n"
157 " } \n"
158 " } \n"
159 " attr { \n"
160 " key: \"is_training\" \n"
161 " value { \n"
162 " b: false \n"
163 " } \n"
164 " } \n"
165 "} \n");
surmeh01bceff2f2018-03-29 16:29:27 +0100166
Matteo Martincigh075c7502018-12-05 13:10:45 +0000167 // Set the input shape according to the data layout
168 std::array<unsigned int, 4> dims;
169 if (dataLayout == "NHWC")
170 {
171 dims = { 1u, 3u, 3u, 1u };
172 }
173 else // dataLayout == "NCHW"
174 {
175 dims = { 1u, 1u, 3u, 3u };
176 }
177
178 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
surmeh01bceff2f2018-03-29 16:29:27 +0100179 }
180};
181
Matteo Martincigh075c7502018-12-05 13:10:45 +0000182struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100183{
Matteo Martincigh075c7502018-12-05 13:10:45 +0000184 FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){}
185};
186BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
187{
188 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
189 { -2.8277204f, -2.12079024f, -1.4138602f,
190 -0.7069301f, 0.0f, 0.7069301f,
191 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
192}
193
194struct FusedBatchNormNchwFixture : FusedBatchNormFixture
195{
196 FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){}
197};
198BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
199{
200 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
201 { -2.8277204f, -2.12079024f, -1.4138602f,
202 -0.7069301f, 0.0f, 0.7069301f,
203 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
surmeh01bceff2f2018-03-29 16:29:27 +0100204}
205
206BOOST_AUTO_TEST_SUITE_END()