blob: a8c99793ad109ec59d09749bf79e1aaaa24d7874 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct SqueezeFixture : public ParserFlatbuffersFixture
16{
17 explicit SqueezeFixture(const std::string& inputShape,
18 const std::string& outputShape,
19 const std::string& squeezeDims)
20 {
21 m_JsonString = R"(
22 {
23 "version": 3,
24 "operator_codes": [ { "builtin_code": "SQUEEZE" } ],
25 "subgraphs": [ {
26 "tensors": [
27 {)";
28 m_JsonString += R"(
29 "shape" : )" + inputShape + ",";
30 m_JsonString += R"(
31 "type": "UINT8",
32 "buffer": 0,
33 "name": "inputTensor",
34 "quantization": {
35 "min": [ 0.0 ],
36 "max": [ 255.0 ],
37 "scale": [ 1.0 ],
38 "zero_point": [ 0 ],
39 }
40 },
41 {)";
42 m_JsonString += R"(
43 "shape" : )" + outputShape;
44 m_JsonString += R"(,
45 "type": "UINT8",
46 "buffer": 1,
47 "name": "outputTensor",
48 "quantization": {
49 "min": [ 0.0 ],
50 "max": [ 255.0 ],
51 "scale": [ 1.0 ],
52 "zero_point": [ 0 ],
53 }
54 }
55 ],
56 "inputs": [ 0 ],
57 "outputs": [ 1 ],
58 "operators": [
59 {
60 "opcode_index": 0,
61 "inputs": [ 0 ],
62 "outputs": [ 1 ],
63 "builtin_options_type": "SqueezeOptions",
64 "builtin_options": {)";
65 if (!squeezeDims.empty())
66 {
67 m_JsonString += R"("squeeze_dims" : )" + squeezeDims;
68 }
69 m_JsonString += R"(},
70 "custom_options_format": "FLEXBUFFERS"
71 }
72 ],
73 } ],
74 "buffers" : [ {}, {} ]
75 }
76 )";
77 }
78};
79
80struct SqueezeFixtureWithSqueezeDims : SqueezeFixture
81{
82 SqueezeFixtureWithSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2, 1 ]", "[ 0, 1, 2 ]") {}
83};
84
85BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithSqueezeDims, SqueezeFixtureWithSqueezeDims)
86{
87 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
88 RunTest<3, uint8_t>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
89 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
90 == armnn::TensorShape({2,2,1})));
91
92}
93
94struct SqueezeFixtureWithoutSqueezeDims : SqueezeFixture
95{
96 SqueezeFixtureWithoutSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2 ]", "") {}
97};
98
99BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithoutSqueezeDims, SqueezeFixtureWithoutSqueezeDims)
100{
101 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
102 RunTest<2, uint8_t>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
103 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
104 == armnn::TensorShape({2,2})));
105}
106
107struct SqueezeFixtureWithInvalidInput : SqueezeFixture
108{
109 SqueezeFixtureWithInvalidInput() : SqueezeFixture("[ 1, 2, 2, 1, 2 ]", "[ 1, 2, 2, 1 ]", "[ ]") {}
110};
111
112BOOST_FIXTURE_TEST_CASE(ParseSqueezeInvalidInput, SqueezeFixtureWithInvalidInput)
113{
114 BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")),
115 armnn::InvalidArgumentException);
116}
117
118struct SqueezeFixtureWithSqueezeDimsSizeInvalid : SqueezeFixture
119{
120 SqueezeFixtureWithSqueezeDimsSizeInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
121 "[ 1, 2, 2, 1 ]",
122 "[ 1, 2, 2, 2, 2 ]") {}
123};
124
125BOOST_FIXTURE_TEST_CASE(ParseSqueezeInvalidSqueezeDims, SqueezeFixtureWithSqueezeDimsSizeInvalid)
126{
127 BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
128}
129
130
131struct SqueezeFixtureWithNegativeSqueezeDims : SqueezeFixture
132{
133 SqueezeFixtureWithNegativeSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]",
134 "[ 1, 2, 2, 1 ]",
135 "[ -2 , 2 ]") {}
136};
137
138BOOST_FIXTURE_TEST_CASE(ParseSqueezeNegativeSqueezeDims, SqueezeFixtureWithNegativeSqueezeDims)
139{
140 BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
141}
142
143
144BOOST_AUTO_TEST_SUITE_END()