blob: d73682961f04ef3fcb3ba7f417837987db2beb8a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Ferran Balaguer51dd62f2019-01-11 19:29:18 +00006#include "armnnTfParser/ITfParser.hpp"
Matthew Bentham4057d912019-01-21 15:45:51 +00007
8#include <ParserPrototxtFixture.hpp>
9#include <PrototxtConversions.hpp>
10
11#include <boost/test/unit_test.hpp>
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000012
13BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
15struct MeanFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16{
17 explicit MeanFixture(const armnn::TensorShape& inputShape, const armnn::TensorShape& outputShape,
18 const std::vector<unsigned int>& axis, bool keepDims)
19 {
20 std::string protobufAxisString;
21 std::vector<unsigned int> protobufAxis(axis);
22
23 // If no axis range is specified, the reduction is applied to
24 // all dimensions of the input tensor
25 if (protobufAxis.size() == 0)
26 {
27 for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
28 {
29 protobufAxis.push_back(i);
30 }
31 }
32
33 for (unsigned int i = 0; i < protobufAxis.size(); ++i)
34 {
Matthew Bentham4057d912019-01-21 15:45:51 +000035 protobufAxisString.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(protobufAxis[i])));
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000036 }
37
38 m_Prototext = R"(node {
39 name: "input"
40 op: "Placeholder"
41 attr {
42 key: "dtype"
43 value {
44 type: DT_FLOAT
45 }
46 }
47 attr {
48 key: "shape"
49 value {
50 shape {
51 }
52 }
53 }
54 }
55 node {
56 name: "Const"
57 op: "Const"
58 attr {
59 key: "dtype"
60 value {
61 type: DT_INT32
62 }
63 }
64 attr {
65 key: "value"
66 value { )";
67
68 if (axis.size() == 1)
69 {
70 m_Prototext.append(R"( tensor {
71 dtype: DT_INT32
72 tensor_shape {
73 }
74 int_val: )").append(std::to_string(protobufAxis[0])).append(R"(
75 } )");
76 }
77 else
78 {
79 m_Prototext.append(R"( tensor {
80 dtype: DT_INT32
81 tensor_shape {
82 dim {
83 size: 2
84 }
85 }
86 tensor_content: ")").append(protobufAxisString).append(R"("
87 } )");
88 }
89
90 m_Prototext.append(R"( }
91 }
92 }
93 node {
94 name: "output"
95 op: "Mean"
96 input: "input"
97 input: "Const"
98 attr {
99 key: "T"
100 value {
101 type: DT_FLOAT
102 }
103 }
104 attr {
105 key: "Tidx"
106 value {
107 type: DT_INT32
108 }
109 }
110 attr {
111 key: "keep_dims"
112 value {
113 b: )").append(keepDims ? "true" : "false").append(R"(
114 }
115 }
116 })");
117
118 SetupSingleInputSingleOutput(inputShape, outputShape, "input", "output");
119 }
120};
121
122struct MeanNoAxisNoKeepDimsFixture: MeanFixture
123{
124 MeanNoAxisNoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 1 }, {}, false) {}
125};
126
127struct MeanWithAxis0NoKeepDimsFixture: MeanFixture
128{
129 MeanWithAxis0NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 3 }, { 0 }, false) {}
130};
131
132struct MeanWithAxis1NoKeepDimsFixture: MeanFixture
133{
134 MeanWithAxis1NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 2 }, { 1 }, false) {}
135};
136
137struct MeanWithAxis0KeepDimsFixture: MeanFixture
138{
139 MeanWithAxis0KeepDimsFixture() : MeanFixture({ 2, 3 }, { 1, 3 }, { 0 }, true) {}
140};
141
142struct MeanWithAxis1KeepDimsFixture: MeanFixture
143{
144 MeanWithAxis1KeepDimsFixture() : MeanFixture({ 2, 3 }, { 2, 1 }, { 1 }, true) {}
145};
146
147
148BOOST_FIXTURE_TEST_CASE(MeanNoAxisNoKeepDims, MeanNoAxisNoKeepDimsFixture)
149{
150 RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
151 { { "output", { 1.5f } } });
152}
153
154BOOST_FIXTURE_TEST_CASE(MeanWithAxis0NoKeepDims, MeanWithAxis0NoKeepDimsFixture)
155{
156 RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
157 { { "output", { 1.5f, 1.5f, 1.5f } } });
158}
159
160BOOST_FIXTURE_TEST_CASE(MeanWithAxis1NoKeepDims, MeanWithAxis1NoKeepDimsFixture)
161{
162 RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
163 { { "output", { 1.f, 2.f } } });
164}
165
166BOOST_FIXTURE_TEST_CASE(MeanWithAxis0KeepDims, MeanWithAxis0KeepDimsFixture)
167{
168 RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
169 { { "output", { 1.5f, 1.5f, 1.5f } } });
170}
171
172BOOST_FIXTURE_TEST_CASE(MeanWithAxis1KeepDims, MeanWithAxis1KeepDimsFixture)
173{
174 RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
175 { { "output", { 1.f, 2.f } } });
176}
177
178BOOST_AUTO_TEST_SUITE_END()