blob: 0f60250377a17825960d9525dcb0f4d1a2c6b96a [file] [log] [blame]
Aron Virginas-Tar422f2fb2019-10-21 14:09:11 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserFlatbuffersSerializeFixture.hpp"
Finn Williams85d36712021-01-26 22:30:06 +00007#include <armnnDeserializer/IDeserializer.hpp>
Aron Virginas-Tar422f2fb2019-10-21 14:09:11 +01008
Colm Donelanc42a9872022-02-02 16:35:09 +00009#include <armnnUtils/QuantizeHelper.hpp>
Aron Virginas-Tar422f2fb2019-10-21 14:09:11 +010010#include <ResolveType.hpp>
11
Aron Virginas-Tar422f2fb2019-10-21 14:09:11 +010012#include <string>
13
Sadik Armagan1625efc2021-06-10 18:24:34 +010014TEST_SUITE("Deserializer_Comparison")
15{
// DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType)
// Declares struct Simple<operation><dataType>Fixture, a SimpleComparisonFixture whose
// serialized network uses the data type and comparison operation named by the macro
// arguments (both are passed on as strings).
#define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
struct Simple##operation##dataType##Fixture : public SimpleComparisonFixture \
{ \
    Simple##operation##dataType##Fixture() \
        : SimpleComparisonFixture(#dataType, #operation) \
    {} \
};

// DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType)
// Declares the fixture above plus a test case named "<operation><dataType>". The test
// quantizes the shared float inputs with scale 1 and offset 0, runs the network through
// RunTest and compares the Boolean output with s_TestData.m_Output<operation>.
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType) \
DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
TEST_CASE_FIXTURE(Simple##operation##dataType##Fixture, #operation#dataType) \
{ \
    using ElementType = armnn::ResolveType<armnn::DataType::dataType>; \
    constexpr float   quantScale  = 1.f; \
    constexpr int32_t quantOffset = 0; \
    const auto quantizedInput0 = \
        armnnUtils::QuantizedVector<ElementType>(s_TestData.m_InputData0, quantScale, quantOffset); \
    const auto quantizedInput1 = \
        armnnUtils::QuantizedVector<ElementType>(s_TestData.m_InputData1, quantScale, quantOffset); \
    RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>( \
        0, \
        {{ "InputLayer0", quantizedInput0 }, \
         { "InputLayer1", quantizedInput1 }}, \
        {{ "OutputLayer", s_TestData.m_Output##operation }}); \
}
36
37struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
38{
39 explicit ComparisonFixture(const std::string& inputShape0,
40 const std::string& inputShape1,
41 const std::string& outputShape,
42 const std::string& inputDataType,
43 const std::string& comparisonOperation)
44 {
45 m_JsonString = R"(
46 {
47 inputIds: [0, 1],
48 outputIds: [3],
49 layers: [
50 {
51 layer_type: "InputLayer",
52 layer: {
53 base: {
54 layerBindingId: 0,
55 base: {
56 index: 0,
57 layerName: "InputLayer0",
58 layerType: "Input",
59 inputSlots: [{
60 index: 0,
61 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
62 }],
63 outputSlots: [{
64 index: 0,
65 tensorInfo: {
66 dimensions: )" + inputShape0 + R"(,
67 dataType: )" + inputDataType + R"(
68 },
69 }],
70 },
71 }
72 },
73 },
74 {
75 layer_type: "InputLayer",
76 layer: {
77 base: {
78 layerBindingId: 1,
79 base: {
80 index:1,
81 layerName: "InputLayer1",
82 layerType: "Input",
83 inputSlots: [{
84 index: 0,
85 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
86 }],
87 outputSlots: [{
88 index: 0,
89 tensorInfo: {
90 dimensions: )" + inputShape1 + R"(,
91 dataType: )" + inputDataType + R"(
92 },
93 }],
94 },
95 }
96 },
97 },
98 {
99 layer_type: "ComparisonLayer",
100 layer: {
101 base: {
102 index:2,
103 layerName: "ComparisonLayer",
104 layerType: "Comparison",
105 inputSlots: [{
106 index: 0,
107 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
108 },
109 {
110 index: 1,
111 connection: { sourceLayerIndex:1, outputSlotIndex:0 },
112 }],
113 outputSlots: [{
114 index: 0,
115 tensorInfo: {
116 dimensions: )" + outputShape + R"(,
117 dataType: Boolean
118 },
119 }],
120 },
121 descriptor: {
122 operation: )" + comparisonOperation + R"(
123 }
124 },
125 },
126 {
127 layer_type: "OutputLayer",
128 layer: {
129 base:{
130 layerBindingId: 0,
131 base: {
132 index: 3,
133 layerName: "OutputLayer",
134 layerType: "Output",
135 inputSlots: [{
136 index: 0,
137 connection: { sourceLayerIndex:2, outputSlotIndex:0 },
138 }],
139 outputSlots: [{
140 index: 0,
141 tensorInfo: {
142 dimensions: )" + outputShape + R"(,
143 dataType: Boolean
144 },
145 }],
146 }
147 }
148 },
149 }
150 ]
151 }
152 )";
153 Setup();
154 }
155};
156
/// Shared inputs and expected results for the simple element-wise comparison tests.
///
/// Every vector holds 16 elements. The inputs are four groups of four identical values,
/// chosen so that the groups cover "equal", "greater", "less" and "equal" again.
/// Expected outputs use 1 for true and 0 for false.
struct SimpleComparisonTestData
{
    // Left-hand operand.
    std::vector<float> m_InputData0 =
    {
        1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
        3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
    };

    // Right-hand operand.
    std::vector<float> m_InputData1 =
    {
        1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
        5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
    };

    // m_InputData0[i] == m_InputData1[i]
    std::vector<uint8_t> m_OutputEqual =
    {
        1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1
    };

    // m_InputData0[i] > m_InputData1[i]
    std::vector<uint8_t> m_OutputGreater =
    {
        0, 0, 0, 0, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0
    };

    // m_InputData0[i] >= m_InputData1[i]
    std::vector<uint8_t> m_OutputGreaterOrEqual =
    {
        1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1
    };

    // m_InputData0[i] < m_InputData1[i]
    std::vector<uint8_t> m_OutputLess =
    {
        0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0
    };

    // m_InputData0[i] <= m_InputData1[i]
    std::vector<uint8_t> m_OutputLessOrEqual =
    {
        1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1
    };

    // m_InputData0[i] != m_InputData1[i]
    std::vector<uint8_t> m_OutputNotEqual =
    {
        0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0
    };
};
220
221struct SimpleComparisonFixture : public ComparisonFixture
222{
223 SimpleComparisonFixture(const std::string& inputDataType,
224 const std::string& comparisonOperation)
225 : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
226 "[ 2, 2, 2, 2 ]", // inputShape1
227 "[ 2, 2, 2, 2 ]", // outputShape,
228 inputDataType,
229 comparisonOperation) {}
230
231 static SimpleComparisonTestData s_TestData;
232};
233
234SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
235
236DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, Float32)
237DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, Float32)
238DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
239DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, Float32)
240DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, Float32)
241DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, Float32)
242
Derek Lambertif90c56d2020-01-10 17:14:08 +0000243
Derek Lambertif90c56d2020-01-10 17:14:08 +0000244DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QAsymmU8)
245DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QAsymmU8)
246DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
247DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, QAsymmU8)
248DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QAsymmU8)
249DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QAsymmU8)
Aron Virginas-Tar422f2fb2019-10-21 14:09:11 +0100250
Sadik Armagan1625efc2021-06-10 18:24:34 +0100251}