blob: 9f7335f7b51c8d511dccee0ff397c8c44a984fd2 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8#include <sstream>
9
Kevin May7d96b162021-02-03 17:38:41 +000010using armnnTfLiteParser::TfLiteParserImpl;
telsoa01c577f2c2018-08-31 09:22:23 +010011
Sadik Armagan1625efc2021-06-10 18:24:34 +010012TEST_SUITE("TensorflowLiteParser_GetBuffer")
13{
telsoa01c577f2c2018-08-31 09:22:23 +010014struct GetBufferFixture : public ParserFlatbuffersFixture
15{
16 explicit GetBufferFixture()
17 {
18 m_JsonString = R"(
19 {
20 "version": 3,
21 "operator_codes": [ { "builtin_code": "CONV_2D" } ],
22 "subgraphs": [ {
23 "tensors": [
24 {
25 "shape": [ 1, 3, 3, 1 ],
26 "type": "UINT8",
27 "buffer": 0,
28 "name": "inputTensor",
29 "quantization": {
30 "min": [ 0.0 ],
31 "max": [ 255.0 ],
32 "scale": [ 1.0 ],
33 "zero_point": [ 0 ],
34 }
35 },
36 {
37 "shape": [ 1, 1, 1, 1 ],
38 "type": "UINT8",
39 "buffer": 1,
40 "name": "outputTensor",
41 "quantization": {
42 "min": [ 0.0 ],
43 "max": [ 511.0 ],
44 "scale": [ 2.0 ],
45 "zero_point": [ 0 ],
46 }
47 },
48 {
49 "shape": [ 1, 3, 3, 1 ],
50 "type": "UINT8",
51 "buffer": 2,
52 "name": "filterTensor",
53 "quantization": {
54 "min": [ 0.0 ],
55 "max": [ 255.0 ],
56 "scale": [ 1.0 ],
57 "zero_point": [ 0 ],
58 }
59 }
60 ],
61 "inputs": [ 0 ],
62 "outputs": [ 1 ],
63 "operators": [
64 {
65 "opcode_index": 0,
66 "inputs": [ 0, 2 ],
67 "outputs": [ 1 ],
68 "builtin_options_type": "Conv2DOptions",
69 "builtin_options": {
70 "padding": "VALID",
71 "stride_w": 1,
72 "stride_h": 1,
73 "fused_activation_function": "NONE"
74 },
75 "custom_options_format": "FLEXBUFFERS"
76 }
77 ],
78 } ],
79 "buffers" : [
80 { },
81 { },
82 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
83 { },
84 ]
85 }
86 )";
87 ReadStringToBinary();
88 }
89
Kevin May7d96b162021-02-03 17:38:41 +000090 void CheckBufferContents(const TfLiteParserImpl::ModelPtr& model,
telsoa01c577f2c2018-08-31 09:22:23 +010091 std::vector<int32_t> bufferValues, size_t bufferIndex)
92 {
93 for(long unsigned int i=0; i<bufferValues.size(); i++)
94 {
Sadik Armagan1625efc2021-06-10 18:24:34 +010095 CHECK_EQ(TfLiteParserImpl::GetBuffer(model, bufferIndex)->data[i], bufferValues[i]);
telsoa01c577f2c2018-08-31 09:22:23 +010096 }
97 }
98};
99
Sadik Armagan1625efc2021-06-10 18:24:34 +0100100TEST_CASE_FIXTURE(GetBufferFixture, "GetBufferCheckContents")
telsoa01c577f2c2018-08-31 09:22:23 +0100101{
102 //Check contents of buffer are correct
Kevin May7d96b162021-02-03 17:38:41 +0000103 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
104 m_GraphBinary.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100105 std::vector<int32_t> bufferValues = {2,1,0,6,2,1,4,1,2};
106 CheckBufferContents(model, bufferValues, 2);
107}
108
Sadik Armagan1625efc2021-06-10 18:24:34 +0100109TEST_CASE_FIXTURE(GetBufferFixture, "GetBufferCheckEmpty")
telsoa01c577f2c2018-08-31 09:22:23 +0100110{
111 //Check if test fixture buffers are empty or not
Kevin May7d96b162021-02-03 17:38:41 +0000112 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
113 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100114 CHECK(TfLiteParserImpl::GetBuffer(model, 0)->data.empty());
115 CHECK(TfLiteParserImpl::GetBuffer(model, 1)->data.empty());
116 CHECK(!TfLiteParserImpl::GetBuffer(model, 2)->data.empty());
117 CHECK(TfLiteParserImpl::GetBuffer(model, 3)->data.empty());
telsoa01c577f2c2018-08-31 09:22:23 +0100118}
119
Sadik Armagan1625efc2021-06-10 18:24:34 +0100120TEST_CASE_FIXTURE(GetBufferFixture, "GetBufferCheckParseException")
telsoa01c577f2c2018-08-31 09:22:23 +0100121{
122 //Check if armnn::ParseException thrown when invalid buffer index used
Kevin May7d96b162021-02-03 17:38:41 +0000123 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
124 m_GraphBinary.size());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100125 CHECK_THROWS_AS(TfLiteParserImpl::GetBuffer(model, 4), armnn::Exception);
telsoa01c577f2c2018-08-31 09:22:23 +0100126}
127
Sadik Armagan1625efc2021-06-10 18:24:34 +0100128}