//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
telsoa01c577f2c2018-08-31 09:22:23 +010012struct MultiplicationFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010013{
14 MultiplicationFixture()
15 {
16 m_Prototext = "node { \n"
17 " name: \"graphInput\" \n"
18 " op: \"Placeholder\" \n"
19 " attr { \n"
20 " key: \"dtype\" \n"
21 " value { \n"
22 " type: DT_FLOAT \n"
23 " } \n"
24 " } \n"
25 " attr { \n"
26 " key: \"shape\" \n"
27 " value { \n"
28 " shape { \n"
29 " } \n"
30 " } \n"
31 " } \n"
32 " } \n"
33 " node { \n"
34 " name: \"softmax1\" \n"
35 " op: \"Softmax\" \n"
36 " input: \"graphInput\" \n"
37 " attr { \n"
38 " key: \"T\" \n"
39 " value { \n"
40 " type: DT_FLOAT \n"
41 " } \n"
42 " } \n"
43 " }\n"
44 " node {\n"
45 " name: \"softmax2\"\n"
46 " op : \"Softmax\"\n"
47 " input: \"graphInput\"\n"
48 " attr { \n"
49 " key: \"T\" \n"
50 " value { \n"
51 " type: DT_FLOAT \n"
52 " } \n"
53 " } \n"
54 " }\n"
55 " node {\n"
56 " name: \"multiplication\"\n"
57 " op : \"Mul\"\n"
58 " input: \"softmax1\"\n"
59 " input: \"softmax2\"\n"
60 " attr { \n"
61 " key: \"T\" \n"
62 " value { \n"
63 " type: DT_FLOAT \n"
64 " } \n"
65 " } \n"
66 " }\n";
67
68 SetupSingleInputSingleOutput({ 1, 7 }, "graphInput", "multiplication");
69 }
70};
71
72BOOST_FIXTURE_TEST_CASE(ParseMultiplication, MultiplicationFixture)
73{
74 RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 });
75}
76
telsoa01c577f2c2018-08-31 09:22:23 +010077struct MultiplicationBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010078{
79 MultiplicationBroadcastFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1)
80 {
81 m_Prototext = R"(
82node {
83 name: "input0"
84 op: "Placeholder"
85 attr {
86 key: "dtype"
87 value {
88 type: DT_FLOAT
89 }
90 }
91 attr {
92 key: "shape"
93 value {
94 shape {
95 }
96 }
97 }
98}
99node {
100 name: "input1"
101 op: "Placeholder"
102 attr {
103 key: "dtype"
104 value {
105 type: DT_FLOAT
106 }
107 }
108 attr {
109 key: "shape"
110 value {
111 shape {
112 }
113 }
114 }
115}
116node {
117 name: "output"
118 op: "Mul"
119 input: "input0"
120 input: "input1"
121 attr {
122 key: "T"
123 value {
124 type: DT_FLOAT
125 }
126 }
127}
128 )";
129
130 Setup({ { "input0", inputShape0 },
131 { "input1", inputShape1 } },
132 { "output" });
133 }
134};
135
136struct MultiplicationBroadcastFixture4D1D : public MultiplicationBroadcastFixture
137{
138 MultiplicationBroadcastFixture4D1D() : MultiplicationBroadcastFixture({ 1, 2, 2, 3 }, { 1 }) {}
139};
140
141BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D)
142{
143 RunTest<4>({ { "input0", { 0.0f, 1.0f, 2.0f,
144 3.0f, 4.0f, 5.0f,
145 6.0f, 7.0f, 8.0f,
146 9.0f, 10.0f, 11.0f } },
147 { "input1", { 5.0f } } },
148 { { "output", { 0.0f, 5.0f, 10.0f,
149 15.0f, 20.0f, 25.0f,
150 30.0f, 35.0f, 40.0f,
151 45.0f, 50.0f, 55.0f } } });
152}
153
154struct MultiplicationBroadcastFixture1D4D : public MultiplicationBroadcastFixture
155{
156 MultiplicationBroadcastFixture1D4D() : MultiplicationBroadcastFixture({ 1 }, { 1, 2, 2, 3 }) {}
157};
158
159BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D)
160{
161 RunTest<4>({ { "input0", { 3.0f } },
162 { "input1", { 0.0f, 1.0f, 2.0f,
163 3.0f, 4.0f, 5.0f,
164 6.0f, 7.0f, 8.0f,
165 9.0f, 10.0f, 11.0f } } },
166 { { "output", { 0.0f, 3.0f, 6.0f,
167 9.0f, 12.0f, 15.0f,
168 18.0f, 21.0f, 24.0f,
169 27.0f, 30.0f, 33.0f } } });
170}
171
172BOOST_AUTO_TEST_SUITE_END()