blob: 9541114d71d91a77938614b80e494cdd4ce73b42 [file] [log] [blame]
Ryan OShea86704732020-05-26 11:41:04 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct SplitVFixture : public ParserFlatbuffersFixture
16{
17 explicit SplitVFixture(const std::string& inputShape,
18 const std::string& splitValues,
19 const std::string& sizeSplitsShape,
20 const std::string& axisShape,
21 const std::string& numSplits,
22 const std::string& outputShape1,
23 const std::string& outputShape2,
24 const std::string& axisData,
25 const std::string& dataType)
26 {
27 m_JsonString = R"(
28 {
29 "version": 3,
30 "operator_codes": [ { "builtin_code": "SPLIT_V" } ],
31 "subgraphs": [ {
32 "tensors": [
33 {
34 "shape": )" + inputShape + R"(,
35 "type": )" + dataType + R"(,
36 "buffer": 0,
37 "name": "inputTensor",
38 "quantization": {
39 "min": [ 0.0 ],
40 "max": [ 255.0 ],
41 "scale": [ 1.0 ],
42 "zero_point": [ 0 ],
43 }
44 },
45 {
46 "shape": )" + sizeSplitsShape + R"(,
47 "type": "INT32",
48 "buffer": 1,
49 "name": "sizeSplits",
50 "quantization": {
51 "min": [ 0.0 ],
52 "max": [ 255.0 ],
53 "scale": [ 1.0 ],
54 "zero_point": [ 0 ],
55 }
56 },
57 {
58 "shape": )" + axisShape + R"(,
59 "type": "INT32",
60 "buffer": 2,
61 "name": "axis",
62 "quantization": {
63 "min": [ 0.0 ],
64 "max": [ 255.0 ],
65 "scale": [ 1.0 ],
66 "zero_point": [ 0 ],
67 }
68 },
69 {
70 "shape": )" + outputShape1 + R"( ,
71 "type":)" + dataType + R"(,
72 "buffer": 3,
73 "name": "outputTensor1",
74 "quantization": {
75 "min": [ 0.0 ],
76 "max": [ 255.0 ],
77 "scale": [ 1.0 ],
78 "zero_point": [ 0 ],
79 }
80 },
81 {
82 "shape": )" + outputShape2 + R"( ,
83 "type":)" + dataType + R"(,
84 "buffer": 4,
85 "name": "outputTensor2",
86 "quantization": {
87 "min": [ 0.0 ],
88 "max": [ 255.0 ],
89 "scale": [ 1.0 ],
90 "zero_point": [ 0 ],
91 }
92 }
93 ],
94 "inputs": [ 0, 1, 2 ],
95 "outputs": [ 3, 4 ],
96 "operators": [
97 {
98 "opcode_index": 0,
99 "inputs": [ 0, 1, 2 ],
100 "outputs": [ 3, 4 ],
101 "builtin_options_type": "SplitVOptions",
102 "builtin_options": {
103 "num_splits": )" + numSplits + R"(
104 },
105 "custom_options_format": "FLEXBUFFERS"
106 }
107 ],
108 } ],
109 "buffers" : [ {}, { "data": )" + splitValues + R"( }, { "data": )" + axisData + R"( }, {}, {}]
110 }
111 )";
112
113 Setup();
114 }
115};
116
/*
 * Inferring a split size from a -1 entry in splitValues (e.g. [-1, 1]) has only been
 * tested locally; none of the fixtures below exercise that path.
 */
120
121struct SimpleSplitVAxisOneFixture : SplitVFixture
122{
123 SimpleSplitVAxisOneFixture()
Jan Eilersc0761e92020-06-29 16:48:44 +0100124 : SplitVFixture( "[ 4, 2, 2, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
Ryan OShea86704732020-05-26 11:41:04 +0100125 "[ 1, 2, 2, 2 ]", "[ 3, 2, 2, 2 ]", "[ 0, 0, 0, 0 ]", "FLOAT32")
126 {}
127};
128
129BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitVTwo, SimpleSplitVAxisOneFixture)
130{
131 RunTest<4, armnn::DataType::Float32>(
132 0,
133 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
134 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
135 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
136 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
137 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
138 {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
139 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
140 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
141}
142
143struct SimpleSplitVAxisTwoFixture : SplitVFixture
144{
145 SimpleSplitVAxisTwoFixture()
Jan Eilersc0761e92020-06-29 16:48:44 +0100146 : SplitVFixture( "[ 2, 4, 2, 2 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
Ryan OShea86704732020-05-26 11:41:04 +0100147 "[ 2, 3, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
148 {}
149};
150
151BOOST_FIXTURE_TEST_CASE(ParseAxisTwoSplitVTwo, SimpleSplitVAxisTwoFixture)
152{
153 RunTest<4, armnn::DataType::Float32>(
154 0,
155 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
156 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
157 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
158 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
159 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
160 9.0f, 10.0f, 11.0f, 12.0f, 17.0f, 18.0f, 19.0f, 20.0f,
161 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f } },
162 {"outputTensor2", { 13.0f, 14.0f, 15.0f, 16.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
163}
164
165struct SimpleSplitVAxisThreeFixture : SplitVFixture
166{
167 SimpleSplitVAxisThreeFixture()
Jan Eilersc0761e92020-06-29 16:48:44 +0100168 : SplitVFixture( "[ 2, 2, 4, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
Ryan OShea86704732020-05-26 11:41:04 +0100169 "[ 2, 2, 1, 2 ]", "[ 2, 2, 3, 2 ]", "[ 2, 0, 0, 0 ]", "FLOAT32")
170 {}
171};
172
173BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitVTwo, SimpleSplitVAxisThreeFixture)
174{
175 RunTest<4, armnn::DataType::Float32>(
176 0,
177 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
178 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
179 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
180 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
181 { {"outputTensor1", { 1.0f, 2.0f, 9.0f, 10.0f, 17.0f, 18.0f, 25.0f, 26.0f } },
182 {"outputTensor2", { 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 11.0f, 12.0f,
183 13.0f, 14.0f, 15.0f, 16.0f, 19.0f, 20.0f, 21.0f, 22.0f,
184 23.0f, 24.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
185}
186
187struct SimpleSplitVAxisFourFixture : SplitVFixture
188{
189 SimpleSplitVAxisFourFixture()
Jan Eilersc0761e92020-06-29 16:48:44 +0100190 : SplitVFixture( "[ 2, 2, 2, 4 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
Ryan OShea86704732020-05-26 11:41:04 +0100191 "[ 2, 2, 2, 3 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32")
192 {}
193};
194
195BOOST_FIXTURE_TEST_CASE(ParseAxisFourSplitVTwo, SimpleSplitVAxisFourFixture)
196{
197 RunTest<4, armnn::DataType::Float32>(
198 0,
199 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
200 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
201 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
202 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
203 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 5.0f, 6.0f, 7.0f, 9.0f, 10.0f,
204 11.0f, 13.0f, 14.0f, 15.0f, 17.0f, 18.0f, 19.0f, 21.0f,
205 22.0f, 23.0f, 25.0f, 26.0f, 27.0f, 29.0f, 30.0f, 31.0f} },
206 {"outputTensor2", { 4.0f, 8.0f, 12.0f, 16.0f, 20.0f, 24.0f, 28.0f, 32.0f } } } );
207}
208
209BOOST_AUTO_TEST_SUITE_END()