blob: 31fbfbf6a9d6c04665a5a96b8b560319c73568a2 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
Sadik Armagan1625efc2021-06-10 18:24:34 +01005
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "ParserFlatbuffersFixture.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01007
Kevin May7d96b162021-02-03 17:38:41 +00008using armnnTfLiteParser::TfLiteParserImpl;
9using ModelPtr = TfLiteParserImpl::ModelPtr;
telsoa01c577f2c2018-08-31 09:22:23 +010010
Sadik Armagan1625efc2021-06-10 18:24:34 +010011TEST_SUITE("TensorflowLiteParser_GetTensorIds")
12{
telsoa01c577f2c2018-08-31 09:22:23 +010013struct GetTensorIdsFixture : public ParserFlatbuffersFixture
14{
15 explicit GetTensorIdsFixture(const std::string& inputs, const std::string& outputs)
16 {
17 m_JsonString = R"(
18 {
19 "version": 3,
20 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
21 "subgraphs": [
22 {
23 "tensors": [
24 {
25 "shape": [ 1, 1, 1, 1 ] ,
26 "type": "UINT8",
27 "buffer": 0,
28 "name": "OutputTensor",
29 "quantization": {
30 "min": [ 0.0 ],
31 "max": [ 255.0 ],
32 "scale": [ 1.0 ],
33 "zero_point": [ 0 ]
34 }
35 },
36 {
37 "shape": [ 1, 2, 2, 1 ] ,
38 "type": "UINT8",
39 "buffer": 1,
40 "name": "InputTensor",
41 "quantization": {
42 "min": [ 0.0 ],
43 "max": [ 255.0 ],
44 "scale": [ 1.0 ],
45 "zero_point": [ 0 ]
46 }
47 }
48 ],
49 "inputs": [ 1 ],
50 "outputs": [ 0 ],
51 "operators": [ {
52 "opcode_index": 0,
53 "inputs": )"
54 + inputs
55 + R"(,
56 "outputs": )"
57 + outputs
58 + R"(,
59 "builtin_options_type": "Pool2DOptions",
60 "builtin_options":
61 {
62 "padding": "VALID",
63 "stride_w": 2,
64 "stride_h": 2,
65 "filter_width": 2,
66 "filter_height": 2,
67 "fused_activation_function": "NONE"
68 },
69 "custom_options_format": "FLEXBUFFERS"
70 } ]
71 }
72 ],
73 "description": "Test loading a model",
74 "buffers" : [ {}, {} ]
75 })";
76
77 ReadStringToBinary();
78 }
79};
80
81struct GetEmptyTensorIdsFixture : GetTensorIdsFixture
82{
83 GetEmptyTensorIdsFixture() : GetTensorIdsFixture("[ ]", "[ ]") {}
84};
85
86struct GetInputOutputTensorIdsFixture : GetTensorIdsFixture
87{
88 GetInputOutputTensorIdsFixture() : GetTensorIdsFixture("[ 0, 1, 2 ]", "[ 3 ]") {}
89};
90
Sadik Armagan1625efc2021-06-10 18:24:34 +010091TEST_CASE_FIXTURE(GetEmptyTensorIdsFixture, "GetEmptyInputTensorIds")
telsoa01c577f2c2018-08-31 09:22:23 +010092{
Kevin May7d96b162021-02-03 17:38:41 +000093 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
94 m_GraphBinary.size());
telsoa01c577f2c2018-08-31 09:22:23 +010095 std::vector<int32_t> expectedIds = { };
Kevin May7d96b162021-02-03 17:38:41 +000096 std::vector<int32_t> inputTensorIds = TfLiteParserImpl::GetInputTensorIds(model, 0, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +010097 CHECK(std::equal(expectedIds.begin(), expectedIds.end(),
98 inputTensorIds.begin(), inputTensorIds.end()));
telsoa01c577f2c2018-08-31 09:22:23 +010099}
100
Sadik Armagan1625efc2021-06-10 18:24:34 +0100101TEST_CASE_FIXTURE(GetEmptyTensorIdsFixture, "GetEmptyOutputTensorIds")
telsoa01c577f2c2018-08-31 09:22:23 +0100102{
Kevin May7d96b162021-02-03 17:38:41 +0000103 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
104 m_GraphBinary.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100105 std::vector<int32_t> expectedIds = { };
Kevin May7d96b162021-02-03 17:38:41 +0000106 std::vector<int32_t> outputTensorIds = TfLiteParserImpl::GetOutputTensorIds(model, 0, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100107 CHECK(std::equal(expectedIds.begin(), expectedIds.end(),
108 outputTensorIds.begin(), outputTensorIds.end()));
telsoa01c577f2c2018-08-31 09:22:23 +0100109}
110
Sadik Armagan1625efc2021-06-10 18:24:34 +0100111TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIds")
telsoa01c577f2c2018-08-31 09:22:23 +0100112{
Kevin May7d96b162021-02-03 17:38:41 +0000113 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
114 m_GraphBinary.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100115 std::vector<int32_t> expectedInputIds = { 0, 1, 2 };
Kevin May7d96b162021-02-03 17:38:41 +0000116 std::vector<int32_t> inputTensorIds = TfLiteParserImpl::GetInputTensorIds(model, 0, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100117 CHECK(std::equal(expectedInputIds.begin(), expectedInputIds.end(),
118 inputTensorIds.begin(), inputTensorIds.end()));
telsoa01c577f2c2018-08-31 09:22:23 +0100119}
120
Sadik Armagan1625efc2021-06-10 18:24:34 +0100121TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIds")
telsoa01c577f2c2018-08-31 09:22:23 +0100122{
Kevin May7d96b162021-02-03 17:38:41 +0000123 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
124 m_GraphBinary.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100125 std::vector<int32_t> expectedOutputIds = { 3 };
Kevin May7d96b162021-02-03 17:38:41 +0000126 std::vector<int32_t> outputTensorIds = TfLiteParserImpl::GetOutputTensorIds(model, 0, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100127 CHECK(std::equal(expectedOutputIds.begin(), expectedOutputIds.end(),
128 outputTensorIds.begin(), outputTensorIds.end()));
telsoa01c577f2c2018-08-31 09:22:23 +0100129}
130
Sadik Armagan1625efc2021-06-10 18:24:34 +0100131TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsNullModel")
telsoa01c577f2c2018-08-31 09:22:23 +0100132{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100133 CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(nullptr, 0, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100134}
135
Sadik Armagan1625efc2021-06-10 18:24:34 +0100136TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIdsNullModel")
telsoa01c577f2c2018-08-31 09:22:23 +0100137{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100138 CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(nullptr, 0, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100139}
140
Sadik Armagan1625efc2021-06-10 18:24:34 +0100141TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsInvalidSubgraph")
telsoa01c577f2c2018-08-31 09:22:23 +0100142{
Kevin May7d96b162021-02-03 17:38:41 +0000143 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
144 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100145 CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(model, 1, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100146}
147
Sadik Armagan1625efc2021-06-10 18:24:34 +0100148TEST_CASE_FIXTURE( GetInputOutputTensorIdsFixture, "GetOutputTensorIdsInvalidSubgraph")
telsoa01c577f2c2018-08-31 09:22:23 +0100149{
Kevin May7d96b162021-02-03 17:38:41 +0000150 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
151 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100152 CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(model, 1, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100153}
154
Sadik Armagan1625efc2021-06-10 18:24:34 +0100155TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetInputTensorIdsInvalidOperator")
telsoa01c577f2c2018-08-31 09:22:23 +0100156{
Kevin May7d96b162021-02-03 17:38:41 +0000157 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
158 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100159 CHECK_THROWS_AS(TfLiteParserImpl::GetInputTensorIds(model, 0, 1), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100160}
161
Sadik Armagan1625efc2021-06-10 18:24:34 +0100162TEST_CASE_FIXTURE(GetInputOutputTensorIdsFixture, "GetOutputTensorIdsInvalidOperator")
telsoa01c577f2c2018-08-31 09:22:23 +0100163{
Kevin May7d96b162021-02-03 17:38:41 +0000164 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
165 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100166 CHECK_THROWS_AS(TfLiteParserImpl::GetOutputTensorIds(model, 0, 1), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100167}
168
Sadik Armagan1625efc2021-06-10 18:24:34 +0100169}