blob: de6b5d861e17536d6fef193443657e3a7f7b77d6 [file] [log] [blame]
Sadik Armagan2ad6cb42018-12-27 11:23:44 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 SplitFixture() {
15 m_Prototext =
16 "node { \n"
17 " name: \"graphInput\" \n"
18 " op: \"Placeholder\" \n"
19 " attr { \n"
20 " key: \"dtype\" \n"
21 " value { \n"
22 " type: DT_FLOAT \n"
23 " } \n"
24 " } \n"
25 " attr { \n"
26 " key: \"shape\" \n"
27 " value { \n"
28 " shape { \n"
29 " } \n"
30 " } \n"
31 " } \n"
32 " } \n"
33 " node {"
34 " name: \"splitInput\" \n"
35 " op: \"Const\" \n"
36 "attr {\n"
37 " key: \"dtype\" \n"
38 " value {"
39 " type: DT_INT32"
40 " }"
41 "}"
42 "attr {"
43 " key: \"value\"\n"
44 " value { "
45 " tensor {"
46 " dtype: DT_INT32"
47 " tensor_shape {"
48 "}"
49 "int_val: 1"
50 "}"
51 "}"
52 "}"
53 "}"
54 "node { \n"
55 " name: \"Split\" \n"
56 " op: \"Split\" \n"
57 "input: \"graphInput\"\n"
58 "input: \"splitInput\"\n"
59 "attr { \n "
60 "key: \"T\"\n"
61 "value {\n"
62 "type: DT_FLOAT\n"
63 " }\n"
64 "}\n"
65 "\n"
66 " attr { \n"
67 " key: \"num_or_size_splits\" \n"
68 " value { \n"
69 " i:2 \n "
70 " } \n"
71 " } \n"
72 "} \n"
73 "node { \n"
74 "name: \"Relu_1\"\n"
75 "op: \"Relu\"\n"
76 "input: \"Split:0\"\n"
77 "attr { \n "
78 "key: \"T\"\n"
79 "value {\n"
80 "type: DT_FLOAT\n"
81 " }\n"
82 "}\n"
83 "}\n"
84 "node { \n"
85 "name: \"Relu_2\"\n"
86 "op: \"Relu\"\n"
87 "input: \"Split:1\"\n"
88 "attr { \n "
89 "key: \"T\"\n"
90 "value {\n"
91 "type: DT_FLOAT\n"
92 " }\n"
93 "}\n"
94 "}\n";
95
96 Setup( { { "graphInput", { 1, 2, 2 , 2} } },
97 { "Relu_1", "Relu_2" });
98 }
99};
100
101BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
102{
103 BOOST_TEST(
104 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
105
106 BOOST_TEST(
107 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
108
109 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
110 { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
111 { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
112}
113
114BOOST_AUTO_TEST_SUITE_END()