blob: 57d472d41d4e4063808e92306f44c50723a1ba22 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
// Opens the "TensorflowParser" Boost.Test suite; it is closed by BOOST_AUTO_TEST_SUITE_END() at the end of the file.
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 ExpandDimsFixture(const std::string& expandDim)
15 {
16 m_Prototext =
17 "node { \n"
18 " name: \"graphInput\" \n"
19 " op: \"Placeholder\" \n"
20 " attr { \n"
21 " key: \"dtype\" \n"
22 " value { \n"
23 " type: DT_FLOAT \n"
24 " } \n"
25 " } \n"
26 " attr { \n"
27 " key: \"shape\" \n"
28 " value { \n"
29 " shape { \n"
30 " } \n"
31 " } \n"
32 " } \n"
33 " } \n"
34 "node { \n"
35 " name: \"ExpandDims\" \n"
36 " op: \"ExpandDims\" \n"
37 " input: \"graphInput\" \n"
38 " attr { \n"
39 " key: \"T\" \n"
40 " value { \n"
41 " type: DT_FLOAT \n"
42 " } \n"
43 " } \n"
44 " attr { \n"
45 " key: \"Tdim\" \n"
46 " value { \n";
47 m_Prototext += "i:" + expandDim;
48 m_Prototext +=
49 " } \n"
50 " } \n"
51 "} \n";
52
53 SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
54 }
55};
56
57struct ExpandZeroDim : ExpandDimsFixture
58{
59 ExpandZeroDim() : ExpandDimsFixture("0") {}
60};
61
62struct ExpandTwoDim : ExpandDimsFixture
63{
64 ExpandTwoDim() : ExpandDimsFixture("2") {}
65};
66
67struct ExpandThreeDim : ExpandDimsFixture
68{
69 ExpandThreeDim() : ExpandDimsFixture("3") {}
70};
71
72struct ExpandMinusOneDim : ExpandDimsFixture
73{
74 ExpandMinusOneDim() : ExpandDimsFixture("-1") {}
75};
76
77struct ExpandMinusThreeDim : ExpandDimsFixture
78{
79 ExpandMinusThreeDim() : ExpandDimsFixture("-3") {}
80};
81
82BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
83{
84 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
85 armnn::TensorShape({1, 2, 3, 5})));
86}
87
88BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
89{
90 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
91 armnn::TensorShape({2, 3, 1, 5})));
92}
93
94BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
95{
96 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
97 armnn::TensorShape({2, 3, 5, 1})));
98}
99
100BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
101{
102 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
103 armnn::TensorShape({2, 3, 5, 1})));
104}
105
106BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
107{
108 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
109 armnn::TensorShape({2, 1, 3, 5})));
110}
111
// Closes the "TensorflowParser" suite opened with BOOST_AUTO_TEST_SUITE at the top of the file.
112BOOST_AUTO_TEST_SUITE_END()