blob: bf42bf7c5d76ea11657a94b5f849dbeaa7c4b28e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000014 SplitFixture(bool withDimZero=false) {
15 m_Prototext = R"(
16 node {
17 name: "graphInput"
18 op: "Placeholder"
19 attr {
20 key: "dtype"
21 value {
22 type: DT_FLOAT
23 }
24 }
25 attr {
26 key: "shape"
27 value {
28 shape {
29 }
30 }
31 }
32 }
33 node {
34 name: "graphInput2"
35 op: "Placeholder"
36 attr {
37 key: "dtype"
38 value {
39 type: DT_FLOAT
40 }
41 }
42 attr {
43 key: "shape"
44 value {
45 shape {
46 }
47 }
48 }
49 }
50 node {
51 name: "multiplication"
52 op : "Mul"
53 input: "graphInput"
54 input: "graphInput2"
55 attr {
56 key: "T"
57 value {
58 type: DT_FLOAT
59 }
60 }
61 }
62 node {
63 name: "SplitInput"
64 op: "Const"
65 attr {
66 key: "dtype"
67 value {
68 type: DT_INT32
69 }
70 }
71 attr {
72 key: "value"
73 value {
74 tensor {
75 dtype: DT_INT32
76 tensor_shape {
77 }
78 int_val: )";
Sadik Armagan2ad6cb42018-12-27 11:23:44 +000079
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000080 if(withDimZero)
81 {
82 m_Prototext += std::to_string(3);
83 }
84 else
85 {
86 m_Prototext += std::to_string(1);
87 }
88
89 m_Prototext += R"(
90 }
91 }
92 }
93 }
94 node {
95 name: "Split"
96 op: "Split" )";
97 if(withDimZero)
98 {
99 m_Prototext += "input: \"SplitInput\"\n";
100 m_Prototext += "input: \"multiplication\"\n";
101 }
102 else
103 {
104 m_Prototext += "input: \"graphInput\"\n";
105 m_Prototext += "input: \"SplitInput\"\n";
106 }
107 m_Prototext += R"(
108 attr {
Saoirse Stewart120196d2019-02-28 11:32:41 +0000109 key: "num_split"
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000110 value {
111 i: 2
112 }
113 }
114 }
115 node {
116 name: "Relu_1"
117 op: "Relu"
118 input: "Split:0"
119 attr {
120 key: "T"
121 value {
122 type: DT_FLOAT
123 }
124 }
125 }
126 node {
127 name: "Relu_2"
128 op: "Relu"
129 input:"Split:1"
130 attr {
131 key: "T"
132 value {
133 type: DT_FLOAT
134 }
135 }
136 } )";
137
138 Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000139 { "Relu_1", "Relu_2" });
140 }
141};
142
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000143struct InputFirstSplitFixture : SplitFixture
144{
145 InputFirstSplitFixture() : SplitFixture(true) {}
146};
147
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000148BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
149{
150 BOOST_TEST(
151 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
152
153 BOOST_TEST(
154 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155
156 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
157 { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
158 { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
159}
160
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000161BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
162{
163
164 BOOST_TEST(
165 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
166
167 BOOST_TEST(
168 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169
170 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
171 { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
172 { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
173 { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
174}
175
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000176BOOST_AUTO_TEST_SUITE_END()