blob: 84e7a7e7a94fc22a6948ff50f825656271b1e531 [file] [log] [blame]
surmeh01bceff2f2018-03-29 16:29:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9#include <string>
10#include <iostream>
11
12BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
14struct DepthwiseConvolution2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
15{
16 explicit DepthwiseConvolution2dFixture(const char* paddingType)
17 {
18 m_Prototext = "node { \n"
19 " name: \"graphInput\" \n"
20 " op: \"Placeholder\" \n"
21 " attr { \n"
22 " key: \"dtype\" \n"
23 " value { \n"
24 " type: DT_FLOAT \n"
25 " } \n"
26 " } \n"
27 " attr { \n"
28 " key: \"value\" \n"
29 " value { \n"
30 " tensor { \n"
31 " dtype: DT_FLOAT \n"
32 " tensor_shape { \n"
33 " dim { \n"
34 " size: 1 \n"
35 " } \n"
36 " dim { \n"
37 " size: 1 \n"
38 " } \n"
39 " dim { \n"
40 " size: 3 \n"
41 " } \n"
42 " dim { \n"
43 " size: 3 \n"
44 " } \n"
45 " } \n"
46 " tensor_content: \"\\000\\000\\200?\\000\\000\\000@\\000\\000@@\\000\\000\\200@"
47 "\\000\\000\\240@\\000\\000\\300@\\000\\000\\340@\\000\\000\\000A\\000\\000\\020A\" \n"
48 " } \n"
49 " } \n"
50 " } \n"
51 " } \n"
52 " node { \n"
53 " name: \"Const_1\" \n"
54 " op: \"Const\" \n"
55 " attr { \n"
56 " key: \"dtype\" \n"
57 " value { \n"
58 " type: DT_FLOAT \n"
59 " } \n"
60 " } \n"
61 " attr { \n"
62 " key: \"value\" \n"
63 " value { \n"
64 " tensor { \n"
65 " dtype: DT_FLOAT \n"
66 " tensor_shape { \n"
67 " dim { \n"
68 " size: 1 \n"
69 " } \n"
70 " dim { \n"
71 " size: 3 \n"
72 " } \n"
73 " dim { \n"
74 " size: 3 \n"
75 " } \n"
76 " dim { \n"
77 " size: 3 \n"
78 " } \n"
79 " } \n"
80 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
81 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
82 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
83 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
84 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
85 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
86 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
87 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
88 "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
89 " } \n"
90 " } \n"
91 " } \n"
92 "} \n"
93 "node { \n"
94 " name: \"potato\" \n"
95 " op: \"DepthwiseConv2dNative\" \n"
96 " input: \"graphInput\" \n"
97 " input: \"Const_1\" \n"
98 " attr { \n"
99 " key: \"T\" \n"
100 " value { \n"
101 " type: DT_FLOAT \n"
102 " } \n"
103 " } \n"
104 " attr { \n"
105 " key: \"data_format\" \n"
106 " value { \n"
107 " s: \"NHWC\" \n"
108 " } \n"
109 " } \n"
110 " attr { \n"
111 " key: \"padding\" \n"
112 " value { \n"
113 " s: \"";
114 m_Prototext.append(paddingType);
115 m_Prototext.append("\"\n"
116 " } \n"
117 " } \n"
118 " attr { \n"
119 " key: \"strides\" \n"
120 " value { \n"
121 " list { \n"
122 " i: 1 \n"
123 " i: 1 \n"
124 " i: 1 \n"
125 " i: 1 \n"
126 " } \n"
127 " } \n"
128 " } \n"
129 " attr { \n"
130 " key: \"use_cudnn_on_gpu\" \n"
131 " value { \n"
132 " b: false \n"
133 " } \n"
134 " } \n"
135 "} \n");
136
137 SetupSingleInputSingleOutput({ 1, 1, 3, 3 }, "graphInput", "potato");
138 }
139};
140
141struct DepthwiseConvolution2dSameFixture : DepthwiseConvolution2dFixture
142{
143 DepthwiseConvolution2dSameFixture() : DepthwiseConvolution2dFixture("SAME") { }
144};
145
146BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DSame, DepthwiseConvolution2dSameFixture)
147{
148 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
149 { 2.5f, 5.f, 2.5f, 3.5f, 7.f, 3.5f, 4.5f, 9.f, 4.5f,
150 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f,
151 5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f});
152}
153
154struct DepthwiseConvolution2dValidFixture : DepthwiseConvolution2dFixture
155{
156 DepthwiseConvolution2dValidFixture() : DepthwiseConvolution2dFixture("VALID") { }
157};
158
159BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DValid, DepthwiseConvolution2dValidFixture)
160{
161 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data
162 { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f }); // output expected data
163}
164
165
166BOOST_AUTO_TEST_SUITE_END()