blob: 9a2fabf5bdc338d17f7499fae8ef5f50f19b70c1 [file] [log] [blame]
Aron Virginas-Tar422f2fb2019-10-21 14:09:11 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ParserFlatbuffersSerializeFixture.hpp"
7#include "../Deserializer.hpp"
8
9#include <ResolveType.hpp>
10
11#include <backendsCommon/test/QuantizeHelper.hpp>
12
13#include <boost/test/unit_test.hpp>
14
15#include <string>
16
17BOOST_AUTO_TEST_SUITE(Deserializer)
18
19#define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
20struct Simple##operation##dataType##Fixture : public SimpleComparisonFixture \
21{ \
22 Simple##operation##dataType##Fixture() \
23 : SimpleComparisonFixture(#dataType, #operation) {} \
24};
25
// DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType)
//
// Declares Simple<operation><dataType>Fixture (via DECLARE_SIMPLE_COMPARISON_FIXTURE)
// and a Boost test case named <operation><dataType> that uses it. The test converts
// s_TestData.m_InputData0 / m_InputData1 with QuantizedVector<> for dataType,
// binds them to "InputLayer0" / "InputLayer1", and expects "OutputLayer" to
// produce s_TestData.m_Output<operation> as a Boolean tensor.
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType) \
DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \
BOOST_FIXTURE_TEST_CASE(operation##dataType, Simple##operation##dataType##Fixture) \
{ \
    using InputType = armnn::ResolveType<armnn::DataType::dataType>; \
    /* The serialized tensor infos carry no quantization parameters, so the    */ \
    /* inputs are converted with scale 1 / offset 0 (presumably the schema     */ \
    /* defaults - confirm if the JSON ever starts specifying them).            */ \
    constexpr float   quantScale  = 1.f; \
    constexpr int32_t quantOffset = 0; \
    RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>( \
        0, \
        { \
            { "InputLayer0", QuantizedVector<InputType>(quantScale, quantOffset, s_TestData.m_InputData0) }, \
            { "InputLayer1", QuantizedVector<InputType>(quantScale, quantOffset, s_TestData.m_InputData1) } \
        }, \
        { \
            { "OutputLayer", s_TestData.m_Output##operation } \
        }); \
}
39
40struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
41{
42 explicit ComparisonFixture(const std::string& inputShape0,
43 const std::string& inputShape1,
44 const std::string& outputShape,
45 const std::string& inputDataType,
46 const std::string& comparisonOperation)
47 {
48 m_JsonString = R"(
49 {
50 inputIds: [0, 1],
51 outputIds: [3],
52 layers: [
53 {
54 layer_type: "InputLayer",
55 layer: {
56 base: {
57 layerBindingId: 0,
58 base: {
59 index: 0,
60 layerName: "InputLayer0",
61 layerType: "Input",
62 inputSlots: [{
63 index: 0,
64 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
65 }],
66 outputSlots: [{
67 index: 0,
68 tensorInfo: {
69 dimensions: )" + inputShape0 + R"(,
70 dataType: )" + inputDataType + R"(
71 },
72 }],
73 },
74 }
75 },
76 },
77 {
78 layer_type: "InputLayer",
79 layer: {
80 base: {
81 layerBindingId: 1,
82 base: {
83 index:1,
84 layerName: "InputLayer1",
85 layerType: "Input",
86 inputSlots: [{
87 index: 0,
88 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
89 }],
90 outputSlots: [{
91 index: 0,
92 tensorInfo: {
93 dimensions: )" + inputShape1 + R"(,
94 dataType: )" + inputDataType + R"(
95 },
96 }],
97 },
98 }
99 },
100 },
101 {
102 layer_type: "ComparisonLayer",
103 layer: {
104 base: {
105 index:2,
106 layerName: "ComparisonLayer",
107 layerType: "Comparison",
108 inputSlots: [{
109 index: 0,
110 connection: { sourceLayerIndex:0, outputSlotIndex:0 },
111 },
112 {
113 index: 1,
114 connection: { sourceLayerIndex:1, outputSlotIndex:0 },
115 }],
116 outputSlots: [{
117 index: 0,
118 tensorInfo: {
119 dimensions: )" + outputShape + R"(,
120 dataType: Boolean
121 },
122 }],
123 },
124 descriptor: {
125 operation: )" + comparisonOperation + R"(
126 }
127 },
128 },
129 {
130 layer_type: "OutputLayer",
131 layer: {
132 base:{
133 layerBindingId: 0,
134 base: {
135 index: 3,
136 layerName: "OutputLayer",
137 layerType: "Output",
138 inputSlots: [{
139 index: 0,
140 connection: { sourceLayerIndex:2, outputSlotIndex:0 },
141 }],
142 outputSlots: [{
143 index: 0,
144 tensorInfo: {
145 dimensions: )" + outputShape + R"(,
146 dataType: Boolean
147 },
148 }],
149 }
150 }
151 },
152 }
153 ]
154 }
155 )";
156 Setup();
157 }
158};
159
/// Shared inputs and expected outputs for the element-wise comparison tests.
///
/// Each vector holds 16 values. The inputs are laid out as four runs of four
/// identical values, so each run exercises one relation between
/// a = m_InputData0[i] and b = m_InputData1[i]:
///   elements  0-3 : a == b (1 vs 1)
///   elements  4-7 : a >  b (5 vs 3)
///   elements  8-11: a <  b (3 vs 5)
///   elements 12-15: a == b (4 vs 4)
/// Each m_Output<Operation> vector holds the expected `a <op> b` per element,
/// encoded as 1 (true) or 0 (false).
struct SimpleComparisonTestData
{
    std::vector<float> m_InputData0 =
    {
        1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
        3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
    };

    std::vector<float> m_InputData1 =
    {
        1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
        5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
    };

    // a == b
    std::vector<uint8_t> m_OutputEqual =
    {
        1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1
    };

    // a > b
    std::vector<uint8_t> m_OutputGreater =
    {
        0, 0, 0, 0, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0
    };

    // a >= b
    std::vector<uint8_t> m_OutputGreaterOrEqual =
    {
        1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1
    };

    // a < b
    std::vector<uint8_t> m_OutputLess =
    {
        0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0
    };

    // a <= b
    std::vector<uint8_t> m_OutputLessOrEqual =
    {
        1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1
    };

    // a != b
    std::vector<uint8_t> m_OutputNotEqual =
    {
        0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0
    };
};
223
224struct SimpleComparisonFixture : public ComparisonFixture
225{
226 SimpleComparisonFixture(const std::string& inputDataType,
227 const std::string& comparisonOperation)
228 : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
229 "[ 2, 2, 2, 2 ]", // inputShape1
230 "[ 2, 2, 2, 2 ]", // outputShape,
231 inputDataType,
232 comparisonOperation) {}
233
234 static SimpleComparisonTestData s_TestData;
235};
236
237SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
238
239DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, Float32)
240DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, Float32)
241DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
242DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, Float32)
243DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, Float32)
244DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, Float32)
245
246DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QuantisedAsymm8)
247DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QuantisedAsymm8)
248DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QuantisedAsymm8)
249DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, QuantisedAsymm8)
250DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QuantisedAsymm8)
251DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QuantisedAsymm8)
252
253BOOST_AUTO_TEST_SUITE_END()