blob: d42ae2e438e535ce82efe7bdcb999ac78aaf2288 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12struct EmptyNetworkFixture : public ParserFlatbuffersFixture
13{
14 explicit EmptyNetworkFixture() {
15 m_JsonString = R"(
16 {
17 "version": 3,
18 "operator_codes": [],
19 "subgraphs": [ {} ]
20 })";
21 }
22};
23
24BOOST_FIXTURE_TEST_CASE(EmptyNetworkHasNoInputsAndOutputs, EmptyNetworkFixture)
25{
26 Setup();
27 BOOST_TEST(m_Parser->GetSubgraphCount() == 1);
28 BOOST_TEST(m_Parser->GetSubgraphInputTensorNames(0).size() == 0);
29 BOOST_TEST(m_Parser->GetSubgraphOutputTensorNames(0).size() == 0);
30}
31
32struct MissingTensorsFixture : public ParserFlatbuffersFixture
33{
34 explicit MissingTensorsFixture()
35 {
36 m_JsonString = R"(
37 {
38 "version": 3,
39 "operator_codes": [],
40 "subgraphs": [{
41 "inputs" : [ 0, 1 ],
42 "outputs" : [ 2, 3 ],
43 }]
44 })";
45 }
46};
47
48BOOST_FIXTURE_TEST_CASE(MissingTensorsThrowException, MissingTensorsFixture)
49{
50 // this throws because it cannot do the input output tensor connections
51 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
52}
53
54struct InvalidTensorsFixture : public ParserFlatbuffersFixture
55{
56 explicit InvalidTensorsFixture()
57 {
58 m_JsonString = R"(
59 {
60 "version": 3,
61 "operator_codes": [ ],
62 "subgraphs": [{
Narumol Prangnawarat4818d462019-04-17 11:22:38 +010063 "tensors": [ {
64 "shape": [ 1, 1, 1, 1, 1 ],
65 "type": "FLOAT32",
66 "name": "In",
67 "buffer": 0
68 }, {
69 "shape": [ 1, 1, 1, 1, 1 ],
70 "type": "FLOAT32",
71 "name": "Out",
72 "buffer": 1
73 }],
74 "inputs" : [ 0 ],
75 "outputs" : [ 1 ],
telsoa01c577f2c2018-08-31 09:22:23 +010076 }]
77 })";
78 }
79};
80
81BOOST_FIXTURE_TEST_CASE(InvalidTensorsThrowException, InvalidTensorsFixture)
82{
Narumol Prangnawarat4818d462019-04-17 11:22:38 +010083 // Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions
telsoa01c577f2c2018-08-31 09:22:23 +010084 BOOST_CHECK_THROW(Setup(), armnn::InvalidArgumentException);
85}
86
87struct ValidTensorsFixture : public ParserFlatbuffersFixture
88{
89 explicit ValidTensorsFixture()
90 {
91 m_JsonString = R"(
92 {
93 "version": 3,
94 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
95 "subgraphs": [{
96 "tensors": [ {
97 "shape": [ 1, 1, 1, 1 ],
98 "type": "FLOAT32",
99 "name": "In",
100 "buffer": 0,
101 }, {
102 "shape": [ 1, 1, 1, 1 ],
103 "type": "FLOAT32",
104 "name": "Out",
105 "buffer": 1,
106 }],
107 "inputs" : [ 0 ],
108 "outputs" : [ 1 ],
109 "operators": [{
110 "opcode_index": 0,
111 "inputs": [ 0 ],
112 "outputs": [ 1 ],
113 "builtin_options_type": "Pool2DOptions",
114 "builtin_options":
115 {
116 "padding": "VALID",
117 "stride_w": 1,
118 "stride_h": 1,
119 "filter_width": 1,
120 "filter_height": 1,
121 "fused_activation_function": "NONE"
122 },
123 "custom_options_format": "FLEXBUFFERS"
124 }]
125 }]
126 })";
127 }
128};
129
130BOOST_FIXTURE_TEST_CASE(GetValidInputOutputTensorNames, ValidTensorsFixture)
131{
132 Setup();
133 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 1u);
134 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
135 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In");
136 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
137}
138
139BOOST_FIXTURE_TEST_CASE(ThrowIfSubgraphIdInvalidForInOutNames, ValidTensorsFixture)
140{
141 Setup();
142
143 // these throw because of the invalid subgraph id
144 BOOST_CHECK_THROW(m_Parser->GetSubgraphInputTensorNames(1), armnn::ParseException);
145 BOOST_CHECK_THROW(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
146}
147
Narumol Prangnawarat4818d462019-04-17 11:22:38 +0100148struct Rank0TensorFixture : public ParserFlatbuffersFixture
149{
150 explicit Rank0TensorFixture()
151 {
152 m_JsonString = R"(
153 {
154 "version": 3,
155 "operator_codes": [ { "builtin_code": "MINIMUM" } ],
156 "subgraphs": [{
157 "tensors": [ {
158 "shape": [ ],
159 "type": "FLOAT32",
160 "name": "In0",
161 "buffer": 0,
162 }, {
163 "shape": [ ],
164 "type": "FLOAT32",
165 "name": "In1",
166 "buffer": 1,
167 }, {
168 "shape": [ ],
169 "type": "FLOAT32",
170 "name": "Out",
171 "buffer": 2,
172 }],
173 "inputs" : [ 0, 1 ],
174 "outputs" : [ 2 ],
175 "operators": [{
176 "opcode_index": 0,
177 "inputs": [ 0, 1 ],
178 "outputs": [ 2 ],
179 "custom_options_format": "FLEXBUFFERS"
180 }]
181 }]
182 }
183 )";
184 }
185};
186
187BOOST_FIXTURE_TEST_CASE(Rank0Tensor, Rank0TensorFixture)
188{
189 Setup();
190 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 2u);
191 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
192 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In0");
193 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[1], "In1");
194 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
195}
196
telsoa01c577f2c2018-08-31 09:22:23 +0100197BOOST_AUTO_TEST_SUITE_END()