//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include <string>
#include <vector>

#include <boost/test/unit_test.hpp>

#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01009// This is a special case for add, which supports broadcasting.
surmeh01bceff2f2018-03-29 16:29:27 +010010BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
telsoa01c577f2c2018-08-31 09:22:23 +010012struct BroadcastForAddFixtureSlot1 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010013{
14 BroadcastForAddFixtureSlot1()
15 {
16 m_Prototext = R"(
17 node {
18 name: "graphInput"
19 op: "Placeholder"
20 attr {
21 key: "dtype"
22 value {
23 type: DT_FLOAT
24 }
25 }
26 attr {
27 key: "shape"
28 value {
29 shape {
30 }
31 }
32 }
33 }
34 node {
35 name: "Const_1"
36 op: "Const"
37 attr {
38 key: "dtype"
39 value {
40 type: DT_FLOAT
41 }
42 }
43 attr {
44 key: "value"
45 value {
46 tensor {
47 dtype: DT_FLOAT
48 tensor_shape {
49 }
50 float_val: 4.0
51 float_val: 5.0
52 }
53 }
54 }
55 }
56 node {
57 name: "Add"
58 op: "Add"
59 input: "graphInput"
60 input: "Const_1"
61 attr {
62 key: "T"
63 value {
64 type: DT_FLOAT
65 }
66 }
67 }
68 )";
69
70 SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add");
71 }
72};
73
telsoa01c577f2c2018-08-31 09:22:23 +010074struct BroadcastForAddFixtureSlot0 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010075{
76 BroadcastForAddFixtureSlot0()
77 {
78 m_Prototext = R"(
79 node {
80 name: "graphInput"
81 op: "Placeholder"
82 attr {
83 key: "dtype"
84 value {
85 type: DT_FLOAT
86 }
87 }
88 attr {
89 key: "shape"
90 value {
91 shape {
92 }
93 }
94 }
95 }
96 node {
97 name: "Const_1"
98 op: "Const"
99 attr {
100 key: "dtype"
101 value {
102 type: DT_FLOAT
103 }
104 }
105 attr {
106 key: "value"
107 value {
108 tensor {
109 dtype: DT_FLOAT
110 tensor_shape {
111 }
112 float_val: 4.0
113 float_val: 5.0
114 }
115 }
116 }
117 }
118 node {
119 name: "Add"
120 op: "Add"
121 input: "Const_1"
122 input: "graphInput"
123 attr {
124 key: "T"
125 value {
126 type: DT_FLOAT
127 }
128 }
129 }
130 )";
131
132 SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add");
133 }
134};
135
136
137BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition1, BroadcastForAddFixtureSlot1)
138{
139 RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 });
140}
141
142BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition0, BroadcastForAddFixtureSlot0)
143{
144 RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 });
145}
146
147
148
149BOOST_AUTO_TEST_SUITE_END()