blob: a2566fced576580be6fc0f4f40e61d9667936302 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct UnsupportedMaximumFixture
13 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
14{
15 UnsupportedMaximumFixture()
16 {
17 m_Prototext = R"(
18 node {
19 name: "graphInput"
20 op: "Placeholder"
21 attr {
22 key: "dtype"
23 value {
24 type: DT_FLOAT
25 }
26 }
27 attr {
28 key: "shape"
29 value {
30 shape {
31 }
32 }
33 }
34 }
35 node {
36 name: "Maximum"
37 op: "Maximum"
38 input: "graphInput"
39 input: "graphInput"
40 attr {
41 key: "dtype"
42 value {
43 type: DT_FLOAT
44 }
45 }
46 }
47 )";
48 }
49};
50
51BOOST_FIXTURE_TEST_CASE(UnsupportedMaximum, UnsupportedMaximumFixture)
52{
53 BOOST_CHECK_THROW(
54 SetupSingleInputSingleOutput({ 1, 1 }, "graphInput", "Maximum"),
55 armnn::ParseException);
56}
57
58struct SupportedMaximumFixture
59 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
60{
61 SupportedMaximumFixture(const std::string & maxInput0,
62 const std::string & maxInput1,
63 const std::string & mulInput0,
64 const std::string & mulInput1)
65 {
66 m_Prototext = R"(
67 node {
68 name: "graphInput"
69 op: "Placeholder"
70 attr {
71 key: "dtype"
72 value { type: DT_FLOAT }
73 }
74 attr {
75 key: "shape"
76 value { shape { } }
77 }
78 }
79 node {
80 name: "Alpha"
81 op: "Const"
82 attr {
83 key: "dtype"
84 value { type: DT_FLOAT }
85 }
86 attr {
87 key: "value"
88 value {
89 tensor {
90 dtype: DT_FLOAT
91 tensor_shape {
92 dim { size: 1 }
93 }
94 float_val: 0.1
95 }
96 }
97 }
98 }
99 node {
100 name: "Mul"
101 op: "Mul"
102 input: ")" + mulInput0 + R"("
103 input: ")" + mulInput1 + R"("
104 attr {
105 key: "T"
106 value { type: DT_FLOAT }
107 }
108 }
109 node {
110 name: "Maximum"
111 op: "Maximum"
112 input: ")" + maxInput0 + R"("
113 input: ")" + maxInput1 + R"("
114 attr {
115 key: "T"
116 value { type: DT_FLOAT }
117 }
118 }
119 )";
120 SetupSingleInputSingleOutput({ 1, 2 }, "graphInput", "Maximum");
121 }
122};
123
124struct LeakyRelu_Max_MulAT_T_Fixture : public SupportedMaximumFixture
125{
126 LeakyRelu_Max_MulAT_T_Fixture()
127 : SupportedMaximumFixture("Mul","graphInput","Alpha","graphInput") {}
128};
129
130BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_MulAT_T, LeakyRelu_Max_MulAT_T_Fixture)
131{
132 RunTest<2>(std::vector<float>({-5.0, 3.0}), {-0.5, 3.0});
133}
134
135struct LeakyRelu_Max_T_MulAT_Fixture : public SupportedMaximumFixture
136{
137 LeakyRelu_Max_T_MulAT_Fixture()
138 : SupportedMaximumFixture("graphInput","Mul","Alpha","graphInput") {}
139};
140
141
142BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_T_MulAT, LeakyRelu_Max_T_MulAT_Fixture)
143{
144 RunTest<2>(std::vector<float>({-10.0, 3.0}), {-1.0, 3.0});
145}
146
147struct LeakyRelu_Max_MulTA_T_Fixture : public SupportedMaximumFixture
148{
149 LeakyRelu_Max_MulTA_T_Fixture()
150 : SupportedMaximumFixture("Mul", "graphInput","graphInput","Alpha") {}
151};
152
153BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_MulTA_T, LeakyRelu_Max_MulTA_T_Fixture)
154{
155 RunTest<2>(std::vector<float>({-5.0, 3.0}), {-0.5, 3.0});
156}
157
158struct LeakyRelu_Max_T_MulTA_Fixture : public SupportedMaximumFixture
159{
160 LeakyRelu_Max_T_MulTA_Fixture()
161 : SupportedMaximumFixture("graphInput", "Mul", "graphInput", "Alpha") {}
162};
163
164BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_T_MulTA, LeakyRelu_Max_T_MulTA_Fixture)
165{
166 RunTest<2>(std::vector<float>({-10.0, 13.0}), {-1.0, 13.0});
167}
168
169BOOST_AUTO_TEST_SUITE_END()