blob: 5ae194be277c9a08b61a0c6689227fcddd58c377 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
9#include <string>
10
11TEST_SUITE("TensorflowLiteParser_Comparison")
12{
13struct ComparisonFixture : public ParserFlatbuffersFixture
14{
15 explicit ComparisonFixture(const std::string& operatorCode,
16 const std::string& dataType,
17 const std::string& inputShape,
18 const std::string& inputShape2,
19 const std::string& outputShape)
20 {
21 m_JsonString = R"(
22 {
23 "version": 3,
24 "operator_codes": [ { "builtin_code": )" + operatorCode + R"( } ],
25 "subgraphs": [ {
26 "tensors": [
27 {
28 "shape": )" + inputShape + R"(,
29 "type": )" + dataType + R"( ,
30 "buffer": 0,
31 "name": "inputTensor",
32 "quantization": {
33 "min": [ 0.0 ],
34 "max": [ 255.0 ],
35 "scale": [ 1.0 ],
36 "zero_point": [ 0 ],
37 }
38 },
39 {
40 "shape": )" + inputShape2 + R"(,
41 "type": )" + dataType + R"( ,
42 "buffer": 1,
43 "name": "inputTensor2",
44 "quantization": {
45 "min": [ 0.0 ],
46 "max": [ 255.0 ],
47 "scale": [ 1.0 ],
48 "zero_point": [ 0 ],
49 }
50 },
51 {
52 "shape": )" + outputShape + R"( ,
53 "type": "BOOL",
54 "buffer": 2,
55 "name": "outputTensor",
56 "quantization": {
57 "min": [ 0.0 ],
58 "max": [ 255.0 ],
59 "scale": [ 1.0 ],
60 "zero_point": [ 0 ],
61 }
62 }
63 ],
64 "inputs": [ 0, 1 ],
65 "outputs": [ 2 ],
66 "operators": [
67 {
68 "opcode_index": 0,
69 "inputs": [ 0, 1 ],
70 "outputs": [ 2 ],
71 "custom_options_format": "FLEXBUFFERS"
72 }
73 ],
74 } ],
75 "buffers" : [
76 { },
77 { }
78 ]
79 }
80 )";
81 Setup();
82 }
83};
84
85struct SimpleEqualFixture : public ComparisonFixture
86{
87 SimpleEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
88};
89
90TEST_CASE_FIXTURE(SimpleEqualFixture, "SimpleEqual")
91{
92 RunTest<2, armnn::DataType::QAsymmU8,
93 armnn::DataType::Boolean>(
94 0,
95 {{"inputTensor", { 0, 1, 2, 3 }},
96 {"inputTensor2", { 0, 1, 5, 6 }}},
97 {{"outputTensor", { 1, 1, 0, 0 }}});
98}
99
100struct BroadcastEqualFixture : public ComparisonFixture
101{
102 BroadcastEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
103};
104
105TEST_CASE_FIXTURE(BroadcastEqualFixture, "BroadcastEqual")
106{
107 RunTest<2, armnn::DataType::QAsymmU8,
108 armnn::DataType::Boolean>(
109 0,
110 {{"inputTensor", { 0, 1, 2, 3 }},
111 {"inputTensor2", { 0, 1 }}},
112 {{"outputTensor", { 1, 1, 0, 0 }}});
113}
114
115struct SimpleNotEqualFixture : public ComparisonFixture
116{
117 SimpleNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
118};
119
120TEST_CASE_FIXTURE(SimpleNotEqualFixture, "SimpleNotEqual")
121{
122 RunTest<2, armnn::DataType::QAsymmU8,
123 armnn::DataType::Boolean>(
124 0,
125 {{"inputTensor", { 0, 1, 2, 3 }},
126 {"inputTensor2", { 0, 1, 5, 6 }}},
127 {{"outputTensor", { 0, 0, 1, 1 }}});
128}
129
130struct BroadcastNotEqualFixture : public ComparisonFixture
131{
132 BroadcastNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
133};
134
135TEST_CASE_FIXTURE(BroadcastNotEqualFixture, "BroadcastNotEqual")
136{
137 RunTest<2, armnn::DataType::QAsymmU8,
138 armnn::DataType::Boolean>(
139 0,
140 {{"inputTensor", { 0, 1, 2, 3 }},
141 {"inputTensor2", { 0, 1 }}},
142 {{"outputTensor", { 0, 0, 1, 1 }}});
143}
144
145struct SimpleGreaterFixture : public ComparisonFixture
146{
147 SimpleGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
148};
149
150TEST_CASE_FIXTURE(SimpleGreaterFixture, "SimpleGreater")
151{
152 RunTest<2, armnn::DataType::QAsymmU8,
153 armnn::DataType::Boolean>(
154 0,
155 {{"inputTensor", { 0, 2, 3, 6 }},
156 {"inputTensor2", { 0, 1, 5, 3 }}},
157 {{"outputTensor", { 0, 1, 0, 1 }}});
158}
159
160struct BroadcastGreaterFixture : public ComparisonFixture
161{
162 BroadcastGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
163};
164
165TEST_CASE_FIXTURE(BroadcastGreaterFixture, "BroadcastGreater")
166{
167 RunTest<2, armnn::DataType::QAsymmU8,
168 armnn::DataType::Boolean>(
169 0,
170 {{"inputTensor", { 5, 4, 1, 0 }},
171 {"inputTensor2", { 2, 3 }}},
172 {{"outputTensor", { 1, 1, 0, 0 }}});
173}
174
175struct SimpleGreaterOrEqualFixture : public ComparisonFixture
176{
177 SimpleGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
178};
179
180TEST_CASE_FIXTURE(SimpleGreaterOrEqualFixture, "SimpleGreaterOrEqual")
181{
182 RunTest<2, armnn::DataType::QAsymmU8,
183 armnn::DataType::Boolean>(
184 0,
185 {{"inputTensor", { 0, 2, 3, 6 }},
186 {"inputTensor2", { 0, 1, 5, 3 }}},
187 {{"outputTensor", { 1, 1, 0, 1 }}});
188}
189
190struct BroadcastGreaterOrEqualFixture : public ComparisonFixture
191{
192 BroadcastGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8",
193 "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
194};
195
196TEST_CASE_FIXTURE(BroadcastGreaterOrEqualFixture, "BroadcastGreaterOrEqual")
197{
198 RunTest<2, armnn::DataType::QAsymmU8,
199 armnn::DataType::Boolean>(
200 0,
201 {{"inputTensor", { 5, 4, 1, 0 }},
202 {"inputTensor2", { 2, 4 }}},
203 {{"outputTensor", { 1, 1, 0, 0 }}});
204}
205
206struct SimpleLessFixture : public ComparisonFixture
207{
208 SimpleLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
209};
210
211TEST_CASE_FIXTURE(SimpleLessFixture, "SimpleLess")
212{
213 RunTest<2, armnn::DataType::QAsymmU8,
214 armnn::DataType::Boolean>(
215 0,
216 {{"inputTensor", { 0, 2, 3, 6 }},
217 {"inputTensor2", { 0, 1, 5, 3 }}},
218 {{"outputTensor", { 0, 0, 1, 0 }}});
219}
220
221struct BroadcastLessFixture : public ComparisonFixture
222{
223 BroadcastLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
224};
225
226TEST_CASE_FIXTURE(BroadcastLessFixture, "BroadcastLess")
227{
228 RunTest<2, armnn::DataType::QAsymmU8,
229 armnn::DataType::Boolean>(
230 0,
231 {{"inputTensor", { 5, 4, 1, 0 }},
232 {"inputTensor2", { 2, 3 }}},
233 {{"outputTensor", { 0, 0, 1, 1 }}});
234}
235
236struct SimpleLessOrEqualFixture : public ComparisonFixture
237{
238 SimpleLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
239};
240
241TEST_CASE_FIXTURE(SimpleLessOrEqualFixture, "SimpleLessOrEqual")
242{
243 RunTest<2, armnn::DataType::QAsymmU8,
244 armnn::DataType::Boolean>(
245 0,
246 {{"inputTensor", { 0, 2, 3, 6 }},
247 {"inputTensor2", { 0, 1, 5, 3 }}},
248 {{"outputTensor", { 1, 0, 1, 0 }}});
249}
250
251struct BroadcastLessOrEqualFixture : public ComparisonFixture
252{
253 BroadcastLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
254};
255
256TEST_CASE_FIXTURE(BroadcastLessOrEqualFixture, "BroadcastLessOrEqual")
257{
258 RunTest<2, armnn::DataType::QAsymmU8,
259 armnn::DataType::Boolean>(
260 0,
261 {{"inputTensor", { 5, 4, 1, 0 }},
262 {"inputTensor2", { 1, 3 }}},
263 {{"outputTensor", { 0, 0, 1, 1 }}});
264}
265
266}