blob: 011312f7c9ee602a27bc211c9d0cb9474097b22f [file] [log] [blame]
Matthew Jacksonbcca1f42019-07-16 11:39:21 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct PackFixture : public ParserFlatbuffersFixture
16{
17 explicit PackFixture(const std::string & inputShape,
18 const unsigned int numInputs,
19 const std::string & outputShape,
20 const std::string & axis)
21 {
22 m_JsonString = R"(
23 {
24 "version": 3,
25 "operator_codes": [ { "builtin_code": "PACK" } ],
26 "subgraphs": [ {
27 "tensors": [)";
28
29 for (unsigned int i = 0; i < numInputs; ++i)
30 {
31 m_JsonString += R"(
32 {
33 "shape": )" + inputShape + R"(,
34 "type": "FLOAT32",
35 "buffer": )" + std::to_string(i) + R"(,
36 "name": "inputTensor)" + std::to_string(i + 1) + R"(",
37 "quantization": {
38 "min": [ 0.0 ],
39 "max": [ 255.0 ],
40 "scale": [ 1.0 ],
41 "zero_point": [ 0 ],
42 }
43 },)";
44 }
45
46 std::string inputIndexes;
47 for (unsigned int i = 0; i < numInputs-1; ++i)
48 {
49 inputIndexes += std::to_string(i) + R"(, )";
50 }
51 inputIndexes += std::to_string(numInputs-1);
52
53 m_JsonString += R"(
54 {
55 "shape": )" + outputShape + R"( ,
56 "type": "FLOAT32",
57 "buffer": )" + std::to_string(numInputs) + R"(,
58 "name": "outputTensor",
59 "quantization": {
60 "min": [ 0.0 ],
61 "max": [ 255.0 ],
62 "scale": [ 1.0 ],
63 "zero_point": [ 0 ],
64 }
65 }
66 ],
67 "inputs": [ )" + inputIndexes + R"( ],
68 "outputs": [ 2 ],
69 "operators": [
70 {
71 "opcode_index": 0,
72 "inputs": [ )" + inputIndexes + R"( ],
73 "outputs": [ 2 ],
74 "builtin_options_type": "PackOptions",
75 "builtin_options": {
76 "axis": )" + axis + R"(,
77 "values_count": )" + std::to_string(numInputs) + R"(
78 },
79 "custom_options_format": "FLEXBUFFERS"
80 }
81 ],
82 } ],
83 "buffers" : [)";
84
85 for (unsigned int i = 0; i < numInputs-1; ++i)
86 {
87 m_JsonString += R"(
88 { },)";
89 }
90 m_JsonString += R"(
91 { }
92 ]
93 })";
94 Setup();
95 }
96};
97
98struct SimplePackFixture : PackFixture
99{
100 SimplePackFixture() : PackFixture("[ 3, 2, 3 ]",
101 2,
102 "[ 3, 2, 3, 2 ]",
103 "3") {}
104};
105
106BOOST_FIXTURE_TEST_CASE(ParsePack, SimplePackFixture)
107{
108 RunTest<4, armnn::DataType::Float32>(
109 0,
110 { {"inputTensor1", { 1, 2, 3,
111 4, 5, 6,
112
113 7, 8, 9,
114 10, 11, 12,
115
116 13, 14, 15,
117 16, 17, 18 } },
118 {"inputTensor2", { 19, 20, 21,
119 22, 23, 24,
120
121 25, 26, 27,
122 28, 29, 30,
123
124 31, 32, 33,
125 34, 35, 36 } } },
126 { {"outputTensor", { 1, 19,
127 2, 20,
128 3, 21,
129
130 4, 22,
131 5, 23,
132 6, 24,
133
134
135 7, 25,
136 8, 26,
137 9, 27,
138
139 10, 28,
140 11, 29,
141 12, 30,
142
143
144 13, 31,
145 14, 32,
146 15, 33,
147
148 16, 34,
149 17, 35,
150 18, 36 } } });
151}
152
153BOOST_AUTO_TEST_SUITE_END()