blob: 43a7ebc28ef0de6fd4d65c875c005886bf475127 [file] [log] [blame]
surmeh01bceff2f2018-03-29 16:29:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh01bceff2f2018-03-29 16:29:27 +01004//
5
surmeh01bceff2f2018-03-29 16:29:27 +01006#include "ParserPrototxtFixture.hpp"
Matteo Martincighe011d202019-11-28 11:35:47 +00007
8#include "armnnTfParser/ITfParser.hpp"
9
10#include <armnnUtils/Permute.hpp>
11
12#include <boost/test/unit_test.hpp>
13
surmeh01bceff2f2018-03-29 16:29:27 +010014#include <string>
15#include <iostream>
16
Ferran Balaguer6a669d72018-12-11 10:29:05 +000017using namespace armnnUtils;
18using namespace armnn;
19
surmeh01bceff2f2018-03-29 16:29:27 +010020BOOST_AUTO_TEST_SUITE(TensorflowParser)
21
telsoa01c577f2c2018-08-31 09:22:23 +010022struct DepthwiseConvolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010023{
Ferran Balaguer6a669d72018-12-11 10:29:05 +000024 explicit DepthwiseConvolution2dFixture(const std::string& dataLayout, const char* paddingType)
surmeh01bceff2f2018-03-29 16:29:27 +010025 {
26 m_Prototext = "node { \n"
27 " name: \"graphInput\" \n"
28 " op: \"Placeholder\" \n"
29 " attr { \n"
30 " key: \"dtype\" \n"
31 " value { \n"
32 " type: DT_FLOAT \n"
33 " } \n"
34 " } \n"
35 " attr { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +000036 " key: \"shape\" \n"
surmeh01bceff2f2018-03-29 16:29:27 +010037 " value { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +000038 " shape { \n"
surmeh01bceff2f2018-03-29 16:29:27 +010039 " } \n"
40 " } \n"
41 " } \n"
42 " } \n"
43 " node { \n"
44 " name: \"Const_1\" \n"
45 " op: \"Const\" \n"
46 " attr { \n"
47 " key: \"dtype\" \n"
48 " value { \n"
49 " type: DT_FLOAT \n"
50 " } \n"
51 " } \n"
52 " attr { \n"
53 " key: \"value\" \n"
54 " value { \n"
55 " tensor { \n"
56 " dtype: DT_FLOAT \n"
57 " tensor_shape { \n"
58 " dim { \n"
59 " size: 1 \n"
60 " } \n"
61 " dim { \n"
62 " size: 3 \n"
63 " } \n"
64 " dim { \n"
65 " size: 3 \n"
66 " } \n"
67 " dim { \n"
68 " size: 3 \n"
69 " } \n"
70 " } \n"
71 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
72 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
73 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
74 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
75 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
76 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
77 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
78 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
79 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
80 " } \n"
81 " } \n"
82 " } \n"
83 "} \n"
84 "node { \n"
85 " name: \"potato\" \n"
86 " op: \"DepthwiseConv2dNative\" \n"
87 " input: \"graphInput\" \n"
88 " input: \"Const_1\" \n"
89 " attr { \n"
90 " key: \"T\" \n"
91 " value { \n"
92 " type: DT_FLOAT \n"
93 " } \n"
94 " } \n"
95 " attr { \n"
96 " key: \"data_format\" \n"
97 " value { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +000098 " s: \"";
99 m_Prototext.append(dataLayout);
100 m_Prototext.append("\"\n"
surmeh01bceff2f2018-03-29 16:29:27 +0100101 " } \n"
102 " } \n"
103 " attr { \n"
104 " key: \"padding\" \n"
105 " value { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000106 " s: \"");
surmeh01bceff2f2018-03-29 16:29:27 +0100107 m_Prototext.append(paddingType);
108 m_Prototext.append("\"\n"
109 " } \n"
110 " } \n"
111 " attr { \n"
112 " key: \"strides\" \n"
113 " value { \n"
114 " list { \n"
115 " i: 1 \n"
116 " i: 1 \n"
117 " i: 1 \n"
118 " i: 1 \n"
119 " } \n"
120 " } \n"
121 " } \n"
122 " attr { \n"
123 " key: \"use_cudnn_on_gpu\" \n"
124 " value { \n"
125 " b: false \n"
126 " } \n"
127 " } \n"
128 "} \n");
129
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000130 if(dataLayout == "NHWC")
131 {
132 SetupSingleInputSingleOutput({ 1u, 1u, 3u, 3u }, "graphInput", "potato");
133 }
134 else
135 {
136 SetupSingleInputSingleOutput({ 1u, 3u, 1u, 3u }, "graphInput", "potato");
137 }
surmeh01bceff2f2018-03-29 16:29:27 +0100138 }
139};
140
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000141struct DepthwiseConvolution2dNhwcSameFixture : DepthwiseConvolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100142{
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000143 DepthwiseConvolution2dNhwcSameFixture() : DepthwiseConvolution2dFixture("NHWC", "SAME") { }
surmeh01bceff2f2018-03-29 16:29:27 +0100144};
145
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000146BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcSame, DepthwiseConvolution2dNhwcSameFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100147{
148 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
149 { 2.5f, 5.f, 2.5f, 3.5f, 7.f, 3.5f, 4.5f, 9.f, 4.5f,
150 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f,
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000151 5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f });
surmeh01bceff2f2018-03-29 16:29:27 +0100152}
153
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000154struct DepthwiseConvolution2dNchwSameFixture : DepthwiseConvolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100155{
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000156 DepthwiseConvolution2dNchwSameFixture() : DepthwiseConvolution2dFixture("NCHW", "SAME") { }
surmeh01bceff2f2018-03-29 16:29:27 +0100157};
158
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000159BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNchwSame, DepthwiseConvolution2dNchwSameFixture)
160{
161 RunTest<4>({ 1, 4, 7, 2, 5, 8, 3, 6, 9 },
162 { 2.5f, 6.f, 5.5f, 5.f, 12.f, 11.f, 2.5f, 6.f, 5.5f,
163 3.5f, 7.5f, 6.5f, 7.f, 15.f, 13.f, 3.5f, 7.5f, 6.5f,
164 4.5f, 9.f, 7.5f, 9.f, 18.f, 15.f, 4.5f, 9.f, 7.5f });
165}
166
167struct DepthwiseConvolution2dNhwcValidFixture : DepthwiseConvolution2dFixture
168{
169 DepthwiseConvolution2dNhwcValidFixture() : DepthwiseConvolution2dFixture("NHWC", "VALID") { }
170};
171
172BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcValid, DepthwiseConvolution2dNhwcValidFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100173{
174 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000175 { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f }); // output expected data
176}
177
178struct DepthwiseConvolution2dNchwValidFixture : DepthwiseConvolution2dFixture
179{
180 DepthwiseConvolution2dNchwValidFixture() : DepthwiseConvolution2dFixture("NCHW", "VALID") { }
181};
182
183BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNchwValid, DepthwiseConvolution2dNchwValidFixture)
184{
185 RunTest<4>({ 1, 4, 7, 2, 5, 8, 3, 6, 9 },
186 { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f });
surmeh01bceff2f2018-03-29 16:29:27 +0100187}
188
189
190BOOST_AUTO_TEST_SUITE_END()