blob: 5c64449c3475e9bf5de490051387441f800360f4 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
Sadik Armagan1625efc2021-06-10 18:24:34 +01005
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
Kevin May7d96b162021-02-03 17:38:41 +00009using armnnTfLiteParser::TfLiteParserImpl;
10using ModelPtr = TfLiteParserImpl::ModelPtr;
11using TensorRawPtr = TfLiteParserImpl::TensorRawPtr;
telsoa01c577f2c2018-08-31 09:22:23 +010012
Sadik Armagan1625efc2021-06-10 18:24:34 +010013TEST_SUITE("TensorflowLiteParser_GetSubgraphInputsOutputs")
14{
telsoa01c577f2c2018-08-31 09:22:23 +010015struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
16{
17 explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
18 {
19 m_JsonString = R"(
20 {
21 "version": 3,
22 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
23 "subgraphs": [
24 {
25 "tensors": [
26 {
27 "shape": [ 1, 1, 1, 1 ] ,
28 "type": "UINT8",
29 "buffer": 0,
30 "name": "OutputTensor",
31 "quantization": {
32 "min": [ 0.0 ],
33 "max": [ 255.0 ],
34 "scale": [ 1.0 ],
35 "zero_point": [ 0 ]
36 }
37 },
38 {
39 "shape": [ 1, 2, 2, 1 ] ,
40 "type": "UINT8",
41 "buffer": 1,
42 "name": "InputTensor",
43 "quantization": {
44 "min": [ -1.2 ],
45 "max": [ 25.5 ],
46 "scale": [ 0.25 ],
47 "zero_point": [ 10 ]
48 }
49 }
50 ],
51 "inputs": )"
52 + inputs
53 + R"(,
54 "outputs": )"
55 + outputs
56 + R"(,
57 "operators": [ {
58 "opcode_index": 0,
59 "inputs": [ 1 ],
60 "outputs": [ 0 ],
61 "builtin_options_type": "Pool2DOptions",
62 "builtin_options":
63 {
64 "padding": "VALID",
65 "stride_w": 2,
66 "stride_h": 2,
67 "filter_width": 2,
68 "filter_height": 2,
69 "fused_activation_function": "NONE"
70 },
71 "custom_options_format": "FLEXBUFFERS"
72 } ]
73 },
74 {
75 "tensors": [
76 {
77 "shape": [ 1, 3, 3, 1 ],
78 "type": "UINT8",
79 "buffer": 0,
80 "name": "ConvInputTensor",
81 "quantization": {
82 "scale": [ 1.0 ],
83 "zero_point": [ 0 ],
84 }
85 },
86 {
87 "shape": [ 1, 1, 1, 1 ],
88 "type": "UINT8",
89 "buffer": 1,
90 "name": "ConvOutputTensor",
91 "quantization": {
92 "min": [ 0.0 ],
93 "max": [ 511.0 ],
94 "scale": [ 2.0 ],
95 "zero_point": [ 0 ],
96 }
97 },
98 {
99 "shape": [ 1, 3, 3, 1 ],
100 "type": "UINT8",
101 "buffer": 2,
102 "name": "filterTensor",
103 "quantization": {
104 "min": [ 0.0 ],
105 "max": [ 255.0 ],
106 "scale": [ 1.0 ],
107 "zero_point": [ 0 ],
108 }
109 }
110 ],
111 "inputs": [ 0 ],
112 "outputs": [ 1 ],
113 "operators": [
114 {
115 "opcode_index": 0,
116 "inputs": [ 0, 2 ],
117 "outputs": [ 1 ],
118 "builtin_options_type": "Conv2DOptions",
119 "builtin_options": {
120 "padding": "VALID",
121 "stride_w": 1,
122 "stride_h": 1,
123 "fused_activation_function": "NONE"
124 },
125 "custom_options_format": "FLEXBUFFERS"
126 }
127 ],
128 }
129 ],
130 "description": "Test Subgraph Inputs Outputs",
131 "buffers" : [
132 { },
133 { },
134 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
135 { },
136 ]
137 })";
138
139 ReadStringToBinary();
140 }
141
142};
143
144struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
145{
146 GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
147};
148
149struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
150{
151 GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
152};
153
Sadik Armagan1625efc2021-06-10 18:24:34 +0100154TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphInputs")
telsoa01c577f2c2018-08-31 09:22:23 +0100155{
Kevin May7d96b162021-02-03 17:38:41 +0000156 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
157 m_GraphBinary.size());
158 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100159 CHECK_EQ(0, subgraphTensors.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100160}
161
Sadik Armagan1625efc2021-06-10 18:24:34 +0100162TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphOutputs")
telsoa01c577f2c2018-08-31 09:22:23 +0100163{
Kevin May7d96b162021-02-03 17:38:41 +0000164 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
165 m_GraphBinary.size());
166 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100167 CHECK_EQ(0, subgraphTensors.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100168}
169
Sadik Armagan1625efc2021-06-10 18:24:34 +0100170TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputs")
telsoa01c577f2c2018-08-31 09:22:23 +0100171{
Kevin May7d96b162021-02-03 17:38:41 +0000172 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
173 m_GraphBinary.size());
174 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100175 CHECK_EQ(1, subgraphTensors.size());
176 CHECK_EQ(1, subgraphTensors[0].first);
telsoa01c577f2c2018-08-31 09:22:23 +0100177 CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
178 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
179}
180
Sadik Armagan1625efc2021-06-10 18:24:34 +0100181TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsSimpleQuantized")
telsoa01c577f2c2018-08-31 09:22:23 +0100182{
Kevin May7d96b162021-02-03 17:38:41 +0000183 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
184 m_GraphBinary.size());
185 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100186 CHECK_EQ(1, subgraphTensors.size());
187 CHECK_EQ(0, subgraphTensors[0].first);
telsoa01c577f2c2018-08-31 09:22:23 +0100188 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
189 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
190}
191
Sadik Armagan1625efc2021-06-10 18:24:34 +0100192TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsEmptyMinMax")
telsoa01c577f2c2018-08-31 09:22:23 +0100193{
Kevin May7d96b162021-02-03 17:38:41 +0000194 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
195 m_GraphBinary.size());
196 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 1);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100197 CHECK_EQ(1, subgraphTensors.size());
198 CHECK_EQ(0, subgraphTensors[0].first);
telsoa01c577f2c2018-08-31 09:22:23 +0100199 CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
200 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
201}
202
Sadik Armagan1625efc2021-06-10 18:24:34 +0100203TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputs")
telsoa01c577f2c2018-08-31 09:22:23 +0100204{
Kevin May7d96b162021-02-03 17:38:41 +0000205 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
206 m_GraphBinary.size());
207 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 1);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100208 CHECK_EQ(1, subgraphTensors.size());
209 CHECK_EQ(1, subgraphTensors[0].first);
telsoa01c577f2c2018-08-31 09:22:23 +0100210 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
211 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
212}
213
Sadik Armagan1625efc2021-06-10 18:24:34 +0100214TEST_CASE("GetSubgraphInputsNullModel")
telsoa01c577f2c2018-08-31 09:22:23 +0100215{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100216 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(nullptr, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100217}
218
Sadik Armagan1625efc2021-06-10 18:24:34 +0100219TEST_CASE("GetSubgraphOutputsNullModel")
telsoa01c577f2c2018-08-31 09:22:23 +0100220{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100221 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(nullptr, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100222}
223
Sadik Armagan1625efc2021-06-10 18:24:34 +0100224TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsInvalidSubgraph")
telsoa01c577f2c2018-08-31 09:22:23 +0100225{
Kevin May7d96b162021-02-03 17:38:41 +0000226 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
227 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100228 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(model, 2), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100229}
230
Sadik Armagan1625efc2021-06-10 18:24:34 +0100231TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsInvalidSubgraph")
telsoa01c577f2c2018-08-31 09:22:23 +0100232{
Kevin May7d96b162021-02-03 17:38:41 +0000233 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
234 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100235 CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(model, 2), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100236}
237
Sadik Armagan1625efc2021-06-10 18:24:34 +0100238}