//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "ParserFlatbuffersFixture.hpp"
#include "../TfLiteParser.hpp"

10BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12struct EmptyNetworkFixture : public ParserFlatbuffersFixture
13{
14 explicit EmptyNetworkFixture() {
15 m_JsonString = R"(
16 {
17 "version": 3,
18 "operator_codes": [],
19 "subgraphs": [ {} ]
20 })";
21 }
22};
23
24BOOST_FIXTURE_TEST_CASE(EmptyNetworkHasNoInputsAndOutputs, EmptyNetworkFixture)
25{
26 Setup();
27 BOOST_TEST(m_Parser->GetSubgraphCount() == 1);
28 BOOST_TEST(m_Parser->GetSubgraphInputTensorNames(0).size() == 0);
29 BOOST_TEST(m_Parser->GetSubgraphOutputTensorNames(0).size() == 0);
30}
31
32struct MissingTensorsFixture : public ParserFlatbuffersFixture
33{
34 explicit MissingTensorsFixture()
35 {
36 m_JsonString = R"(
37 {
38 "version": 3,
39 "operator_codes": [],
40 "subgraphs": [{
41 "inputs" : [ 0, 1 ],
42 "outputs" : [ 2, 3 ],
43 }]
44 })";
45 }
46};
47
48BOOST_FIXTURE_TEST_CASE(MissingTensorsThrowException, MissingTensorsFixture)
49{
50 // this throws because it cannot do the input output tensor connections
51 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
52}
53
54struct InvalidTensorsFixture : public ParserFlatbuffersFixture
55{
56 explicit InvalidTensorsFixture()
57 {
58 m_JsonString = R"(
59 {
60 "version": 3,
61 "operator_codes": [ ],
62 "subgraphs": [{
63 "tensors": [ {}, {}, {}, {} ],
64 "inputs" : [ 0, 1 ],
65 "outputs" : [ 2, 3 ],
66 }]
67 })";
68 }
69};
70
71BOOST_FIXTURE_TEST_CASE(InvalidTensorsThrowException, InvalidTensorsFixture)
72{
73 // this throws because it cannot do the input output tensor connections
74 BOOST_CHECK_THROW(Setup(), armnn::InvalidArgumentException);
75}
76
77struct ValidTensorsFixture : public ParserFlatbuffersFixture
78{
79 explicit ValidTensorsFixture()
80 {
81 m_JsonString = R"(
82 {
83 "version": 3,
84 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
85 "subgraphs": [{
86 "tensors": [ {
87 "shape": [ 1, 1, 1, 1 ],
88 "type": "FLOAT32",
89 "name": "In",
90 "buffer": 0,
91 }, {
92 "shape": [ 1, 1, 1, 1 ],
93 "type": "FLOAT32",
94 "name": "Out",
95 "buffer": 1,
96 }],
97 "inputs" : [ 0 ],
98 "outputs" : [ 1 ],
99 "operators": [{
100 "opcode_index": 0,
101 "inputs": [ 0 ],
102 "outputs": [ 1 ],
103 "builtin_options_type": "Pool2DOptions",
104 "builtin_options":
105 {
106 "padding": "VALID",
107 "stride_w": 1,
108 "stride_h": 1,
109 "filter_width": 1,
110 "filter_height": 1,
111 "fused_activation_function": "NONE"
112 },
113 "custom_options_format": "FLEXBUFFERS"
114 }]
115 }]
116 })";
117 }
118};
119
120BOOST_FIXTURE_TEST_CASE(GetValidInputOutputTensorNames, ValidTensorsFixture)
121{
122 Setup();
123 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 1u);
124 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
125 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In");
126 BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
127}
128
129BOOST_FIXTURE_TEST_CASE(ThrowIfSubgraphIdInvalidForInOutNames, ValidTensorsFixture)
130{
131 Setup();
132
133 // these throw because of the invalid subgraph id
134 BOOST_CHECK_THROW(m_Parser->GetSubgraphInputTensorNames(1), armnn::ParseException);
135 BOOST_CHECK_THROW(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
136}
137
138BOOST_AUTO_TEST_SUITE_END()