blob: 16b1124e24a1df734f6d0a12391feeaa9bf011d7 [file] [log] [blame]
Ferran Balaguerfbdad032018-12-28 18:15:24 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01006#include <armnn/utility/Assert.hpp>
Ferran Balaguerfbdad032018-12-28 18:15:24 +00007#include <boost/test/unit_test.hpp>
8
9#include "armnnTfParser/ITfParser.hpp"
10#include "ParserPrototxtFixture.hpp"
11
12#include <map>
13#include <string>
14
15
16BOOST_AUTO_TEST_SUITE(TensorflowParser)
17
18struct AddNFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
19{
20 AddNFixture(const std::vector<armnn::TensorShape> inputShapes, unsigned int numberOfInputs)
21 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010022 ARMNN_ASSERT(inputShapes.size() == numberOfInputs);
Ferran Balaguerfbdad032018-12-28 18:15:24 +000023 m_Prototext = "";
24 for (unsigned int i = 0; i < numberOfInputs; i++)
25 {
26 m_Prototext.append("node { \n");
27 m_Prototext.append(" name: \"input").append(std::to_string(i)).append("\"\n");
28 m_Prototext += R"( op: "Placeholder"
29 attr {
30 key: "dtype"
31 value {
32 type: DT_FLOAT
33 }
34 }
35 attr {
36 key: "shape"
37 value {
38 shape {
39 }
40 }
41 }
42}
43)";
44 }
45 m_Prototext += R"(node {
46 name: "output"
47 op: "AddN"
48)";
49 for (unsigned int i = 0; i < numberOfInputs; i++)
50 {
51 m_Prototext.append(" input: \"input").append(std::to_string(i)).append("\"\n");
52 }
53 m_Prototext += R"( attr {
54 key: "N"
55 value {
56)";
57 m_Prototext.append(" i: ").append(std::to_string(numberOfInputs)).append("\n");
58 m_Prototext += R"( }
59 }
60 attr {
61 key: "T"
62 value {
63 type: DT_FLOAT
64 }
65 }
66})";
67
68 std::map<std::string, armnn::TensorShape> inputs;
69 for (unsigned int i = 0; i < numberOfInputs; i++)
70 {
71 std::string name("input");
72 name.append(std::to_string(i));
73 inputs.emplace(std::make_pair(name, inputShapes[i]));
74 }
75 Setup(inputs, {"output"});
76 }
77
78};
79
// Exercise AddN with 2, 3, 5 and 8 inputs that all share the same 2x2 shape.
81struct FiveTwoDimInputsFixture : AddNFixture
82{
83 FiveTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } }, 5) {}
84};
85
86
87BOOST_FIXTURE_TEST_CASE(FiveTwoDimInputs, FiveTwoDimInputsFixture)
88{
89 RunTest<2>({ { "input0", { 1.0, 2.0, 3.0, 4.0 } },
90 { "input1", { 1.0, 5.0, 2.0, 2.0 } },
91 { "input2", { 1.0, 1.0, 2.0, 2.0 } },
92 { "input3", { 3.0, 7.0, 1.0, 2.0 } },
93 { "input4", { 8.0, 0.0, -2.0, -3.0 } } },
94 { { "output", { 14.0, 15.0, 6.0, 7.0 } } });
95}
96
97struct TwoTwoDimInputsFixture : AddNFixture
98{
99 TwoTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 } }, 2) {}
100};
101
102BOOST_FIXTURE_TEST_CASE(TwoTwoDimInputs, TwoTwoDimInputsFixture)
103{
104 RunTest<2>({ { "input0", { 1.0, 2.0, 3.0, 4.0 } },
105 { "input1", { 1.0, 5.0, 2.0, 2.0 } } },
106 { { "output", { 2.0, 7.0, 5.0, 6.0 } } });
107}
108
109struct ThreeTwoDimInputsFixture : AddNFixture
110{
111 ThreeTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 } }, 3) {}
112};
113
114BOOST_FIXTURE_TEST_CASE(ThreeTwoDimInputs, ThreeTwoDimInputsFixture)
115{
116 RunTest<2>({ { "input0", { 1.0, 2.0, 3.0, 4.0 } },
117 { "input1", { 1.0, 5.0, 2.0, 2.0 } },
118 { "input2", { 1.0, 1.0, 2.0, 2.0 } } },
119 { { "output", { 3.0, 8.0, 7.0, 8.0 } } });
120}
121
122struct EightTwoDimInputsFixture : AddNFixture
123{
124 EightTwoDimInputsFixture() : AddNFixture({ { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 },
125 { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } }, 8) {}
126};
127
128BOOST_FIXTURE_TEST_CASE(EightTwoDimInputs, EightTwoDimInputsFixture)
129{
130 RunTest<2>({ { "input0", { 1.0, 2.0, 3.0, 4.0 } },
131 { "input1", { 1.0, 5.0, 2.0, 2.0 } },
132 { "input2", { 1.0, 1.0, 2.0, 2.0 } },
133 { "input3", { 3.0, 7.0, 1.0, 2.0 } },
134 { "input4", { 8.0, 0.0, -2.0, -3.0 } },
135 { "input5", {-3.0, 2.0, -1.0, -5.0 } },
136 { "input6", { 1.0, 6.0, 2.0, 2.0 } },
137 { "input7", {-19.0, 7.0, 1.0, -10.0 } } },
138 { { "output", {-7.0, 30.0, 8.0, -6.0 } } });
139}
140
141struct ThreeInputBroadcast1D4D4DInputsFixture : AddNFixture
142{
143 ThreeInputBroadcast1D4D4DInputsFixture() : AddNFixture({ { 1 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, 3) {}
144};
145
146BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast1D4D4DInputs, ThreeInputBroadcast1D4D4DInputsFixture)
147{
148 RunTest<4>({ { "input0", { 1.0 } },
149 { "input1", { 1.0, 5.0, 2.0, 2.0 } },
150 { "input2", { 1.0, 1.0, 2.0, 2.0 } } },
151 { { "output", { 3.0, 7.0, 5.0, 5.0 } } });
152}
153
154struct ThreeInputBroadcast4D1D4DInputsFixture : AddNFixture
155{
156 ThreeInputBroadcast4D1D4DInputsFixture() : AddNFixture({ { 1, 1, 2, 2 }, { 1 }, { 1, 1, 2, 2 } }, 3) {}
157};
158
159BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast4D1D4DInputs, ThreeInputBroadcast4D1D4DInputsFixture)
160{
161 RunTest<4>({ { "input0", { 1.0, 3.0, 9.0, 4.0 } },
162 { "input1", {-2.0 } },
163 { "input2", { 1.0, 1.0, 2.0, 2.0 } } },
164 { { "output", { 0.0, 2.0, 9.0, 4.0 } } });
165}
166
167struct ThreeInputBroadcast4D4D1DInputsFixture : AddNFixture
168{
169 ThreeInputBroadcast4D4D1DInputsFixture() : AddNFixture({ { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1 } }, 3) {}
170};
171
172BOOST_FIXTURE_TEST_CASE(ThreeInputBroadcast4D4D1DInputs, ThreeInputBroadcast4D4D1DInputsFixture)
173{
174 RunTest<4>({ { "input0", { 1.0, 5.0, 2.0, 2.0 } },
175 { "input1", { 1.0, 1.0, 2.0, 2.0 } },
176 { "input2", { 1.0 } } },
177 { { "output", { 3.0, 7.0, 5.0, 5.0 } } });
178}
179
180BOOST_AUTO_TEST_SUITE_END()