blob: 464e62fc23547c266408468e6996c79a70ecec23 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9#include <string>
10#include <iostream>
11
Ferran Balaguer6a669d72018-12-11 10:29:05 +000012#include <Permute.hpp>
13using namespace armnnUtils;
14using namespace armnn;
15
surmeh01bceff2f2018-03-29 16:29:27 +010016BOOST_AUTO_TEST_SUITE(TensorflowParser)
17
telsoa01c577f2c2018-08-31 09:22:23 +010018struct DepthwiseConvolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010019{
Ferran Balaguer6a669d72018-12-11 10:29:05 +000020 explicit DepthwiseConvolution2dFixture(const std::string& dataLayout, const char* paddingType)
surmeh01bceff2f2018-03-29 16:29:27 +010021 {
22 m_Prototext = "node { \n"
23 " name: \"graphInput\" \n"
24 " op: \"Placeholder\" \n"
25 " attr { \n"
26 " key: \"dtype\" \n"
27 " value { \n"
28 " type: DT_FLOAT \n"
29 " } \n"
30 " } \n"
31 " attr { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +000032 " key: \"shape\" \n"
surmeh01bceff2f2018-03-29 16:29:27 +010033 " value { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +000034 " shape { \n"
surmeh01bceff2f2018-03-29 16:29:27 +010035 " } \n"
36 " } \n"
37 " } \n"
38 " } \n"
39 " node { \n"
40 " name: \"Const_1\" \n"
41 " op: \"Const\" \n"
42 " attr { \n"
43 " key: \"dtype\" \n"
44 " value { \n"
45 " type: DT_FLOAT \n"
46 " } \n"
47 " } \n"
48 " attr { \n"
49 " key: \"value\" \n"
50 " value { \n"
51 " tensor { \n"
52 " dtype: DT_FLOAT \n"
53 " tensor_shape { \n"
54 " dim { \n"
55 " size: 1 \n"
56 " } \n"
57 " dim { \n"
58 " size: 3 \n"
59 " } \n"
60 " dim { \n"
61 " size: 3 \n"
62 " } \n"
63 " dim { \n"
64 " size: 3 \n"
65 " } \n"
66 " } \n"
67 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
68 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
69 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
70 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
71 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
72 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
73 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
74 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
75 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
76 " } \n"
77 " } \n"
78 " } \n"
79 "} \n"
80 "node { \n"
81 " name: \"potato\" \n"
82 " op: \"DepthwiseConv2dNative\" \n"
83 " input: \"graphInput\" \n"
84 " input: \"Const_1\" \n"
85 " attr { \n"
86 " key: \"T\" \n"
87 " value { \n"
88 " type: DT_FLOAT \n"
89 " } \n"
90 " } \n"
91 " attr { \n"
92 " key: \"data_format\" \n"
93 " value { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +000094 " s: \"";
95 m_Prototext.append(dataLayout);
96 m_Prototext.append("\"\n"
surmeh01bceff2f2018-03-29 16:29:27 +010097 " } \n"
98 " } \n"
99 " attr { \n"
100 " key: \"padding\" \n"
101 " value { \n"
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000102 " s: \"");
surmeh01bceff2f2018-03-29 16:29:27 +0100103 m_Prototext.append(paddingType);
104 m_Prototext.append("\"\n"
105 " } \n"
106 " } \n"
107 " attr { \n"
108 " key: \"strides\" \n"
109 " value { \n"
110 " list { \n"
111 " i: 1 \n"
112 " i: 1 \n"
113 " i: 1 \n"
114 " i: 1 \n"
115 " } \n"
116 " } \n"
117 " } \n"
118 " attr { \n"
119 " key: \"use_cudnn_on_gpu\" \n"
120 " value { \n"
121 " b: false \n"
122 " } \n"
123 " } \n"
124 "} \n");
125
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000126 if(dataLayout == "NHWC")
127 {
128 SetupSingleInputSingleOutput({ 1u, 1u, 3u, 3u }, "graphInput", "potato");
129 }
130 else
131 {
132 SetupSingleInputSingleOutput({ 1u, 3u, 1u, 3u }, "graphInput", "potato");
133 }
surmeh01bceff2f2018-03-29 16:29:27 +0100134 }
135};
136
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000137struct DepthwiseConvolution2dNhwcSameFixture : DepthwiseConvolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100138{
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000139 DepthwiseConvolution2dNhwcSameFixture() : DepthwiseConvolution2dFixture("NHWC", "SAME") { }
surmeh01bceff2f2018-03-29 16:29:27 +0100140};
141
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000142BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcSame, DepthwiseConvolution2dNhwcSameFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100143{
144 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
145 { 2.5f, 5.f, 2.5f, 3.5f, 7.f, 3.5f, 4.5f, 9.f, 4.5f,
146 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f,
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000147 5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f });
surmeh01bceff2f2018-03-29 16:29:27 +0100148}
149
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000150struct DepthwiseConvolution2dNchwSameFixture : DepthwiseConvolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100151{
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000152 DepthwiseConvolution2dNchwSameFixture() : DepthwiseConvolution2dFixture("NCHW", "SAME") { }
surmeh01bceff2f2018-03-29 16:29:27 +0100153};
154
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000155BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNchwSame, DepthwiseConvolution2dNchwSameFixture)
156{
157 RunTest<4>({ 1, 4, 7, 2, 5, 8, 3, 6, 9 },
158 { 2.5f, 6.f, 5.5f, 5.f, 12.f, 11.f, 2.5f, 6.f, 5.5f,
159 3.5f, 7.5f, 6.5f, 7.f, 15.f, 13.f, 3.5f, 7.5f, 6.5f,
160 4.5f, 9.f, 7.5f, 9.f, 18.f, 15.f, 4.5f, 9.f, 7.5f });
161}
162
163struct DepthwiseConvolution2dNhwcValidFixture : DepthwiseConvolution2dFixture
164{
165 DepthwiseConvolution2dNhwcValidFixture() : DepthwiseConvolution2dFixture("NHWC", "VALID") { }
166};
167
168BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcValid, DepthwiseConvolution2dNhwcValidFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100169{
170 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data
Ferran Balaguer6a669d72018-12-11 10:29:05 +0000171 { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f }); // output expected data
172}
173
174struct DepthwiseConvolution2dNchwValidFixture : DepthwiseConvolution2dFixture
175{
176 DepthwiseConvolution2dNchwValidFixture() : DepthwiseConvolution2dFixture("NCHW", "VALID") { }
177};
178
179BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNchwValid, DepthwiseConvolution2dNchwValidFixture)
180{
181 RunTest<4>({ 1, 4, 7, 2, 5, 8, 3, 6, 9 },
182 { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f });
surmeh01bceff2f2018-03-29 16:29:27 +0100183}
184
185
186BOOST_AUTO_TEST_SUITE_END()