blob: dd73bd90a29adf8710cd7c715a14359d4787e922 [file] [log] [blame]
Sang-Hoon Parkdd3f71b2020-02-18 11:27:35 +00001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "armnnTfParser/ITfParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
9#include <boost/test/unit_test.hpp>
10#include <PrototxtConversions.hpp>
11
12BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
14namespace
15{
16 std::string ConvertInt32VectorToOctalString(const std::vector<unsigned int>& data)
17 {
18 std::stringstream ss;
19 ss << "\"";
20 std::for_each(data.begin(), data.end(), [&ss](unsigned int d) {
21 ss << armnnUtils::ConvertInt32ToOctalString(static_cast<int>(d));
22 });
23 ss << "\"";
24 return ss.str();
25 }
26} // namespace
27
28struct TransposeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
29{
30 TransposeFixture(const armnn::TensorShape& inputShape,
31 const std::vector<unsigned int>& permuteVectorData)
32 {
33 using armnnUtils::ConvertTensorShapeToString;
34 armnn::TensorShape permuteVectorShape({static_cast<unsigned int>(permuteVectorData.size())});
35
36 m_Prototext = "node {\n"
37" name: \"input\"\n"
38" op: \"Placeholder\"\n"
39" attr {\n"
40" key: \"dtype\"\n"
41" value {\n"
42" type: DT_FLOAT\n"
43" }\n"
44" }\n"
45" attr {\n"
46" key: \"shape\"\n"
47" value {\n"
48" shape {\n";
49 m_Prototext.append(ConvertTensorShapeToString(inputShape));
50 m_Prototext.append(
51" }\n"
52" }\n"
53" }\n"
54"}\n"
55"node {\n"
56" name: \"transpose/perm\"\n"
57" op: \"Const\"\n"
58" attr {\n"
59" key: \"dtype\"\n"
60" value {\n"
61" type: DT_INT32\n"
62" }\n"
63" }\n"
64" attr {\n"
65" key: \"value\"\n"
66" value {\n"
67" tensor {\n"
68" dtype: DT_INT32\n"
69" tensor_shape {\n"
70 );
71 m_Prototext.append(ConvertTensorShapeToString(permuteVectorShape));
72 m_Prototext.append(
73" }\n"
74" tensor_content: "
75 );
76 m_Prototext.append(ConvertInt32VectorToOctalString(permuteVectorData) + "\n");
77 m_Prototext.append(
78" }\n"
79" }\n"
80" }\n"
81"}\n"
82 );
83 m_Prototext.append(
84"node {\n"
85" name: \"output\"\n"
86" op: \"Transpose\"\n"
87" input: \"input\"\n"
88" input: \"transpose/perm\"\n"
89" attr {\n"
90" key: \"T\"\n"
91" value {\n"
92" type: DT_FLOAT\n"
93" }\n"
94" }\n"
95" attr {\n"
96" key: \"Tperm\"\n"
97" value {\n"
98" type: DT_INT32\n"
99" }\n"
100" }\n"
101"}\n"
102 );
103 Setup({{"input", inputShape}}, {"output"});
104 }
105};
106
107struct TransposeFixtureWithPermuteData : TransposeFixture
108{
109 TransposeFixtureWithPermuteData()
110 : TransposeFixture({2, 2, 3, 4},
111 std::vector<unsigned int>({1, 3, 2, 0})) {}
112};
113
114BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
115{
116 RunTest<4>(
117 {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
118 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
119 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}},
120 {{"output", {0, 24, 4, 28, 8, 32, 1, 25, 5, 29, 9, 33, 2, 26, 6,
121 30, 10, 34, 3, 27, 7, 31, 11, 35, 12, 36, 16, 40, 20, 44, 13, 37,
122 17, 41, 21, 45, 14, 38, 18, 42, 22, 46, 15, 39, 19, 43, 23, 47}}});
123
124 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape()
125 == armnn::TensorShape({2, 4, 3, 2})));
126}
127
128struct TransposeFixtureWithoutPermuteData : TransposeFixture
129{
130 // In case permute data is not given, it assumes (n-1,...,0) is given
131 // where n is the rank of input tensor.
132 TransposeFixtureWithoutPermuteData()
133 : TransposeFixture({2, 2, 3, 4},
134 std::vector<unsigned int>({3, 2, 1, 0})) {}
135};
136
137BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteData, TransposeFixtureWithoutPermuteData)
138{
139 RunTest<4>(
140 {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
141 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
142 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}},
143 {{"output", {0, 24, 12, 36, 4, 28, 16, 40, 8, 32, 20, 44, 1, 25,
144 13, 37, 5, 29, 17, 41, 9, 33, 21, 45, 2, 26, 14, 38, 6, 30, 18,
145 42,10, 34, 22, 46, 3, 27, 15, 39, 7, 31, 19, 43, 11, 35, 23, 47}}});
146
147 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape()
148 == armnn::TensorShape({4, 3, 2, 2})));
149}
150
151BOOST_AUTO_TEST_SUITE_END()