//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <string>
#include <vector>

10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct UnsupportedMaximumFixture
13 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
14{
15 UnsupportedMaximumFixture()
16 {
17 m_Prototext = R"(
18 node {
19 name: "graphInput"
20 op: "Placeholder"
21 attr {
22 key: "dtype"
23 value {
24 type: DT_FLOAT
25 }
26 }
27 attr {
28 key: "shape"
29 value {
30 shape {
31 }
32 }
33 }
34 }
35 node {
36 name: "Maximum"
37 op: "Maximum"
38 input: "graphInput"
telsoa01c577f2c2018-08-31 09:22:23 +010039 attr {
40 key: "dtype"
41 value {
42 type: DT_FLOAT
43 }
44 }
45 }
46 )";
47 }
48};
49
50BOOST_FIXTURE_TEST_CASE(UnsupportedMaximum, UnsupportedMaximumFixture)
51{
52 BOOST_CHECK_THROW(
53 SetupSingleInputSingleOutput({ 1, 1 }, "graphInput", "Maximum"),
54 armnn::ParseException);
55}
56
57struct SupportedMaximumFixture
58 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
59{
60 SupportedMaximumFixture(const std::string & maxInput0,
61 const std::string & maxInput1,
62 const std::string & mulInput0,
63 const std::string & mulInput1)
64 {
65 m_Prototext = R"(
66 node {
67 name: "graphInput"
68 op: "Placeholder"
69 attr {
70 key: "dtype"
71 value { type: DT_FLOAT }
72 }
73 attr {
74 key: "shape"
75 value { shape { } }
76 }
77 }
78 node {
79 name: "Alpha"
80 op: "Const"
81 attr {
82 key: "dtype"
83 value { type: DT_FLOAT }
84 }
85 attr {
86 key: "value"
87 value {
88 tensor {
89 dtype: DT_FLOAT
90 tensor_shape {
91 dim { size: 1 }
92 }
93 float_val: 0.1
94 }
95 }
96 }
97 }
98 node {
99 name: "Mul"
100 op: "Mul"
101 input: ")" + mulInput0 + R"("
102 input: ")" + mulInput1 + R"("
103 attr {
104 key: "T"
105 value { type: DT_FLOAT }
106 }
107 }
108 node {
109 name: "Maximum"
110 op: "Maximum"
111 input: ")" + maxInput0 + R"("
112 input: ")" + maxInput1 + R"("
113 attr {
114 key: "T"
115 value { type: DT_FLOAT }
116 }
117 }
118 )";
119 SetupSingleInputSingleOutput({ 1, 2 }, "graphInput", "Maximum");
120 }
121};
122
123struct LeakyRelu_Max_MulAT_T_Fixture : public SupportedMaximumFixture
124{
125 LeakyRelu_Max_MulAT_T_Fixture()
126 : SupportedMaximumFixture("Mul","graphInput","Alpha","graphInput") {}
127};
128
129BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_MulAT_T, LeakyRelu_Max_MulAT_T_Fixture)
130{
131 RunTest<2>(std::vector<float>({-5.0, 3.0}), {-0.5, 3.0});
132}
133
134struct LeakyRelu_Max_T_MulAT_Fixture : public SupportedMaximumFixture
135{
136 LeakyRelu_Max_T_MulAT_Fixture()
137 : SupportedMaximumFixture("graphInput","Mul","Alpha","graphInput") {}
138};
139
140
141BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_T_MulAT, LeakyRelu_Max_T_MulAT_Fixture)
142{
143 RunTest<2>(std::vector<float>({-10.0, 3.0}), {-1.0, 3.0});
144}
145
146struct LeakyRelu_Max_MulTA_T_Fixture : public SupportedMaximumFixture
147{
148 LeakyRelu_Max_MulTA_T_Fixture()
149 : SupportedMaximumFixture("Mul", "graphInput","graphInput","Alpha") {}
150};
151
152BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_MulTA_T, LeakyRelu_Max_MulTA_T_Fixture)
153{
154 RunTest<2>(std::vector<float>({-5.0, 3.0}), {-0.5, 3.0});
155}
156
157struct LeakyRelu_Max_T_MulTA_Fixture : public SupportedMaximumFixture
158{
159 LeakyRelu_Max_T_MulTA_Fixture()
160 : SupportedMaximumFixture("graphInput", "Mul", "graphInput", "Alpha") {}
161};
162
163BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_T_MulTA, LeakyRelu_Max_T_MulTA_Fixture)
164{
165 RunTest<2>(std::vector<float>({-10.0, 13.0}), {-1.0, 13.0});
166}
167
168BOOST_AUTO_TEST_SUITE_END()