blob: a6b703b08b92dc8f22a0fbab5a524ed3ea618b52 [file] [log] [blame]
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ComparisonTestImpl.hpp"
7
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +01009#include <Half.hpp>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010010#include <QuantizeHelper.hpp>
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010011#include <ResolveType.hpp>
12
13#include <backendsCommon/Workload.hpp>
14#include <backendsCommon/WorkloadData.hpp>
15
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010016#include <backendsCommon/test/TensorCopyUtils.hpp>
17#include <backendsCommon/test/WorkloadTestUtils.hpp>
18
19#include <test/TensorHelpers.hpp>
20
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010021namespace
22{
23
24template <std::size_t NumDims,
25 armnn::DataType ArmnnInType,
26 typename InType = armnn::ResolveType<ArmnnInType>>
27LayerTestResult<uint8_t, NumDims> ComparisonTestImpl(
28 armnn::IWorkloadFactory & workloadFactory,
29 const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,
30 const armnn::ComparisonDescriptor& descriptor,
31 const armnn::TensorShape& shape0,
32 std::vector<InType> values0,
33 float quantScale0,
34 int quantOffset0,
35 const armnn::TensorShape& shape1,
36 std::vector<InType> values1,
37 float quantScale1,
38 int quantOffset1,
39 const armnn::TensorShape& outShape,
40 std::vector<uint8_t> outValues,
41 float outQuantScale,
42 int outQuantOffset)
43{
Jan Eilers8eb25602020-03-09 12:13:48 +000044 IgnoreUnused(memoryManager);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010045 ARMNN_ASSERT(shape0.GetNumDimensions() == NumDims);
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010046 armnn::TensorInfo inputTensorInfo0(shape0, ArmnnInType, quantScale0, quantOffset0);
47
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010048 ARMNN_ASSERT(shape1.GetNumDimensions() == NumDims);
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010049 armnn::TensorInfo inputTensorInfo1(shape1, ArmnnInType, quantScale1, quantOffset1);
50
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010051 ARMNN_ASSERT(outShape.GetNumDimensions() == NumDims);
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010052 armnn::TensorInfo outputTensorInfo(outShape, armnn::DataType::Boolean, outQuantScale, outQuantOffset);
53
54 auto input0 = MakeTensor<InType, NumDims>(inputTensorInfo0, values0);
55 auto input1 = MakeTensor<InType, NumDims>(inputTensorInfo1, values1);
56
57 LayerTestResult<uint8_t, NumDims> ret(outputTensorInfo);
58
59 std::unique_ptr<armnn::ITensorHandle> inputHandle0 = workloadFactory.CreateTensorHandle(inputTensorInfo0);
60 std::unique_ptr<armnn::ITensorHandle> inputHandle1 = workloadFactory.CreateTensorHandle(inputTensorInfo1);
61 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
62
63 armnn::ComparisonQueueDescriptor qDescriptor;
64 qDescriptor.m_Parameters = descriptor;
65
66 armnn::WorkloadInfo info;
67 AddInputToWorkload(qDescriptor, info, inputTensorInfo0, inputHandle0.get());
68 AddInputToWorkload(qDescriptor, info, inputTensorInfo1, inputHandle1.get());
69 AddOutputToWorkload(qDescriptor, info, outputTensorInfo, outputHandle.get());
70
71 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateComparison(qDescriptor, info);
72
73 inputHandle0->Allocate();
74 inputHandle1->Allocate();
75 outputHandle->Allocate();
76
77 CopyDataToITensorHandle(inputHandle0.get(), input0.origin());
78 CopyDataToITensorHandle(inputHandle1.get(), input1.origin());
79
80 workload->PostAllocationConfigure();
81 ExecuteWorkload(*workload, memoryManager);
82
83 CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
84
85 ret.outputExpected = MakeTensor<uint8_t, NumDims>(outputTensorInfo, outValues);
86 ret.compareBoolean = true;
87
88 return ret;
89}
90
91template <std::size_t NumDims,
92 armnn::DataType ArmnnInType,
93 typename InType = armnn::ResolveType<ArmnnInType>>
94LayerTestResult<uint8_t, NumDims> ComparisonTestImpl(
95 armnn::IWorkloadFactory & workloadFactory,
96 const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,
97 const armnn::ComparisonDescriptor& descriptor,
98 const armnn::TensorShape& shape0,
99 std::vector<InType> values0,
100 const armnn::TensorShape& shape1,
101 std::vector<InType> values1,
102 const armnn::TensorShape outShape,
103 std::vector<uint8_t> outValues,
104 float quantScale = 1.f,
105 int quantOffset = 0)
106{
107 return ComparisonTestImpl<NumDims, ArmnnInType>(
108 workloadFactory,
109 memoryManager,
110 descriptor,
111 shape0,
112 values0,
113 quantScale,
114 quantOffset,
115 shape1,
116 values1,
117 quantScale,
118 quantOffset,
119 outShape,
120 outValues,
121 quantScale,
122 quantOffset);
123}
124
125template<typename TestData>
126std::vector<uint8_t> GetExpectedOutputData(const TestData& testData, armnn::ComparisonOperation operation)
127{
128 switch (operation)
129 {
130 case armnn::ComparisonOperation::Equal:
131 return testData.m_OutputEqual;
132 case armnn::ComparisonOperation::Greater:
133 return testData.m_OutputGreater;
134 case armnn::ComparisonOperation::GreaterOrEqual:
135 return testData.m_OutputGreaterOrEqual;
136 case armnn::ComparisonOperation::Less:
137 return testData.m_OutputLess;
138 case armnn::ComparisonOperation::LessOrEqual:
139 return testData.m_OutputLessOrEqual;
140 case armnn::ComparisonOperation::NotEqual:
141 default:
142 return testData.m_OutputNotEqual;
143 }
144}
145
146template<armnn::DataType ArmnnInType, typename TestData>
147LayerTestResult<uint8_t, 4> ComparisonTestImpl(armnn::IWorkloadFactory& workloadFactory,
148 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
149 const TestData& testData,
150 armnn::ComparisonOperation operation,
151 float quantScale = 1.f,
152 int quantOffset = 0)
153{
154 using T = armnn::ResolveType<ArmnnInType>;
155
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100156 std::vector<T> inputData0 = armnnUtils::QuantizedVector<T>(testData.m_InputData0, quantScale, quantOffset);
157 std::vector<T> inputData1 = armnnUtils::QuantizedVector<T>(testData.m_InputData1, quantScale, quantOffset);
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100158
159 return ComparisonTestImpl<4, ArmnnInType>(
160 workloadFactory,
161 memoryManager,
162 armnn::ComparisonDescriptor(operation),
163 testData.m_InputShape0,
164 inputData0,
165 testData.m_InputShape1,
166 inputData1,
167 testData.m_OutputShape,
168 GetExpectedOutputData(testData, operation),
169 quantScale,
170 quantOffset);
171}
172
173class ComparisonTestData
174{
175public:
176 ComparisonTestData() = default;
177 virtual ~ComparisonTestData() = default;
178
179 armnn::TensorShape m_InputShape0;
180 armnn::TensorShape m_InputShape1;
181 armnn::TensorShape m_OutputShape;
182
183 std::vector<float> m_InputData0;
184 std::vector<float> m_InputData1;
185
186 std::vector<uint8_t> m_OutputEqual;
187 std::vector<uint8_t> m_OutputGreater;
188 std::vector<uint8_t> m_OutputGreaterOrEqual;
189 std::vector<uint8_t> m_OutputLess;
190 std::vector<uint8_t> m_OutputLessOrEqual;
191 std::vector<uint8_t> m_OutputNotEqual;
192};
193
194class SimpleTestData : public ComparisonTestData
195{
196public:
197 SimpleTestData() : ComparisonTestData()
198 {
199 m_InputShape0 = { 2, 2, 2, 2 };
200
201 m_InputShape1 = m_InputShape0;
202 m_OutputShape = m_InputShape0;
203
204 m_InputData0 =
205 {
206 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
207 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
208 };
209
210 m_InputData1 =
211 {
212 1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
213 5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
214 };
215
216 m_OutputEqual =
217 {
218 1, 1, 1, 1, 0, 0, 0, 0,
219 0, 0, 0, 0, 1, 1, 1, 1
220 };
221
222 m_OutputGreater =
223 {
224 0, 0, 0, 0, 1, 1, 1, 1,
225 0, 0, 0, 0, 0, 0, 0, 0
226 };
227
228 m_OutputGreaterOrEqual =
229 {
230 1, 1, 1, 1, 1, 1, 1, 1,
231 0, 0, 0, 0, 1, 1, 1, 1
232 };
233
234 m_OutputLess =
235 {
236 0, 0, 0, 0, 0, 0, 0, 0,
237 1, 1, 1, 1, 0, 0, 0, 0
238 };
239
240 m_OutputLessOrEqual =
241 {
242 1, 1, 1, 1, 0, 0, 0, 0,
243 1, 1, 1, 1, 1, 1, 1, 1
244 };
245
246 m_OutputNotEqual =
247 {
248 0, 0, 0, 0, 1, 1, 1, 1,
249 1, 1, 1, 1, 0, 0, 0, 0
250 };
251 }
252};
253
254class Broadcast1ElementTestData : public ComparisonTestData
255{
256public:
257 Broadcast1ElementTestData() : ComparisonTestData()
258 {
259 m_InputShape0 = { 1, 2, 2, 2 };
260 m_InputShape1 = { 1, 1, 1, 1 };
261
262 m_OutputShape = m_InputShape0;
263
264 m_InputData0 = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f };
265 m_InputData1 = { 3.f };
266
267 m_OutputEqual = { 0, 0, 1, 0, 0, 0, 0, 0 };
268 m_OutputGreater = { 0, 0, 0, 1, 1, 1, 1, 1 };
269 m_OutputGreaterOrEqual = { 0, 0, 1, 1, 1, 1, 1, 1 };
270 m_OutputLess = { 1, 1, 0, 0, 0, 0, 0, 0 };
271 m_OutputLessOrEqual = { 1, 1, 1, 0, 0, 0, 0, 0 };
272 m_OutputNotEqual = { 1, 1, 0, 1, 1, 1, 1, 1 };
273 }
274};
275
276class Broadcast1dVectorTestData : public ComparisonTestData
277{
278public:
279 Broadcast1dVectorTestData() : ComparisonTestData()
280 {
281 m_InputShape0 = { 1, 2, 2, 3 };
282 m_InputShape1 = { 1, 1, 1, 3 };
283
284 m_OutputShape = m_InputShape0;
285
286 m_InputData0 =
287 {
288 1.f, 2.f, 3.f, 4.f, 5.f, 6.f,
289 7.f, 8.f, 9.f, 10.f, 11.f, 12.f
290 };
291
292 m_InputData1 = { 4.f, 5.f, 6.f };
293
294 m_OutputEqual =
295 {
296 0, 0, 0, 1, 1, 1,
297 0, 0, 0, 0, 0, 0
298 };
299
300 m_OutputGreater =
301 {
302 0, 0, 0, 0, 0, 0,
303 1, 1, 1, 1, 1, 1
304 };
305
306 m_OutputGreaterOrEqual =
307 {
308 0, 0, 0, 1, 1, 1,
309 1, 1, 1, 1, 1, 1
310 };
311
312 m_OutputLess =
313 {
314 1, 1, 1, 0, 0, 0,
315 0, 0, 0, 0, 0, 0
316 };
317
318 m_OutputLessOrEqual =
319 {
320 1, 1, 1, 1, 1, 1,
321 0, 0, 0, 0, 0, 0
322 };
323
324 m_OutputNotEqual =
325 {
326 1, 1, 1, 0, 0, 0,
327 1, 1, 1, 1, 1, 1
328 };
329 }
330};
331
332static SimpleTestData s_SimpleTestData;
333static Broadcast1ElementTestData s_Broadcast1ElementTestData;
334static Broadcast1dVectorTestData s_Broadcast1dVectorTestData;
335
336} // anonymous namespace
337
338// Equal
339LayerTestResult<uint8_t, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFactory,
340 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
341{
342 return ComparisonTestImpl<armnn::DataType::Float32>(
343 workloadFactory,
344 memoryManager,
345 s_SimpleTestData,
346 armnn::ComparisonOperation::Equal);
347}
348
349LayerTestResult<uint8_t, 4> EqualBroadcast1ElementTest(
350 armnn::IWorkloadFactory& workloadFactory,
351 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
352{
353 return ComparisonTestImpl<armnn::DataType::Float32>(
354 workloadFactory,
355 memoryManager,
356 s_Broadcast1ElementTestData,
357 armnn::ComparisonOperation::Equal);
358}
359
360LayerTestResult<uint8_t, 4> EqualBroadcast1dVectorTest(
361 armnn::IWorkloadFactory& workloadFactory,
362 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
363{
364 return ComparisonTestImpl<armnn::DataType::Float32>(
365 workloadFactory,
366 memoryManager,
367 s_Broadcast1dVectorTestData,
368 armnn::ComparisonOperation::Equal);
369}
370
371LayerTestResult<uint8_t, 4> EqualSimpleFloat16Test(
372 armnn::IWorkloadFactory& workloadFactory,
373 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
374{
375 return ComparisonTestImpl<armnn::DataType::Float16>(
376 workloadFactory,
377 memoryManager,
378 s_SimpleTestData,
379 armnn::ComparisonOperation::Equal);
380}
381
382LayerTestResult<uint8_t, 4> EqualBroadcast1ElementFloat16Test(
383 armnn::IWorkloadFactory& workloadFactory,
384 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
385{
386 return ComparisonTestImpl<armnn::DataType::Float16>(
387 workloadFactory,
388 memoryManager,
389 s_Broadcast1ElementTestData,
390 armnn::ComparisonOperation::Equal);
391}
392
393LayerTestResult<uint8_t, 4> EqualBroadcast1dVectorFloat16Test(
394 armnn::IWorkloadFactory& workloadFactory,
395 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
396{
397 return ComparisonTestImpl<armnn::DataType::Float16>(
398 workloadFactory,
399 memoryManager,
400 s_Broadcast1dVectorTestData,
401 armnn::ComparisonOperation::Equal);
402}
403
404LayerTestResult<uint8_t, 4> EqualSimpleUint8Test(
405 armnn::IWorkloadFactory& workloadFactory,
406 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
407{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000408 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100409 workloadFactory,
410 memoryManager,
411 s_SimpleTestData,
412 armnn::ComparisonOperation::Equal);
413}
414
415LayerTestResult<uint8_t, 4> EqualBroadcast1ElementUint8Test(
416 armnn::IWorkloadFactory& workloadFactory,
417 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
418{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000419 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100420 workloadFactory,
421 memoryManager,
422 s_Broadcast1ElementTestData,
423 armnn::ComparisonOperation::Equal);
424}
425
426LayerTestResult<uint8_t, 4> EqualBroadcast1dVectorUint8Test(
427 armnn::IWorkloadFactory& workloadFactory,
428 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
429{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000430 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100431 workloadFactory,
432 memoryManager,
433 s_Broadcast1dVectorTestData,
434 armnn::ComparisonOperation::Equal);
435}
436
437// Greater
438LayerTestResult<uint8_t, 4> GreaterSimpleTest(armnn::IWorkloadFactory& workloadFactory,
439 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
440{
441 return ComparisonTestImpl<armnn::DataType::Float32>(
442 workloadFactory,
443 memoryManager,
444 s_SimpleTestData,
445 armnn::ComparisonOperation::Greater);
446}
447
448LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementTest(
449 armnn::IWorkloadFactory& workloadFactory,
450 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
451{
452 return ComparisonTestImpl<armnn::DataType::Float32>(
453 workloadFactory,
454 memoryManager,
455 s_Broadcast1ElementTestData,
456 armnn::ComparisonOperation::Greater);
457}
458
459LayerTestResult<uint8_t, 4> GreaterBroadcast1dVectorTest(
460 armnn::IWorkloadFactory& workloadFactory,
461 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
462{
463 return ComparisonTestImpl<armnn::DataType::Float32>(
464 workloadFactory,
465 memoryManager,
466 s_Broadcast1dVectorTestData,
467 armnn::ComparisonOperation::Greater);
468}
469
470LayerTestResult<uint8_t, 4> GreaterSimpleFloat16Test(
471 armnn::IWorkloadFactory& workloadFactory,
472 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
473{
474 return ComparisonTestImpl<armnn::DataType::Float16>(
475 workloadFactory,
476 memoryManager,
477 s_SimpleTestData,
478 armnn::ComparisonOperation::Greater);
479}
480
481LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementFloat16Test(
482 armnn::IWorkloadFactory& workloadFactory,
483 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
484{
485 return ComparisonTestImpl<armnn::DataType::Float16>(
486 workloadFactory,
487 memoryManager,
488 s_Broadcast1ElementTestData,
489 armnn::ComparisonOperation::Greater);
490}
491
492LayerTestResult<uint8_t, 4> GreaterBroadcast1dVectorFloat16Test(
493 armnn::IWorkloadFactory& workloadFactory,
494 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
495{
496 return ComparisonTestImpl<armnn::DataType::Float16>(
497 workloadFactory,
498 memoryManager,
499 s_Broadcast1dVectorTestData,
500 armnn::ComparisonOperation::Greater);
501}
502
503LayerTestResult<uint8_t, 4> GreaterSimpleUint8Test(
504 armnn::IWorkloadFactory& workloadFactory,
505 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
506{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000507 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100508 workloadFactory,
509 memoryManager,
510 s_SimpleTestData,
511 armnn::ComparisonOperation::Greater);
512}
513
514LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementUint8Test(
515 armnn::IWorkloadFactory& workloadFactory,
516 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
517{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000518 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100519 workloadFactory,
520 memoryManager,
521 s_Broadcast1ElementTestData,
522 armnn::ComparisonOperation::Greater);
523}
524
525LayerTestResult<uint8_t, 4> GreaterBroadcast1dVectorUint8Test(
526 armnn::IWorkloadFactory& workloadFactory,
527 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
528{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000529 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100530 workloadFactory,
531 memoryManager,
532 s_Broadcast1dVectorTestData,
533 armnn::ComparisonOperation::Greater);
534}
535
536// GreaterOrEqual
537LayerTestResult<uint8_t, 4> GreaterOrEqualSimpleTest(
538 armnn::IWorkloadFactory& workloadFactory,
539 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
540{
541 return ComparisonTestImpl<armnn::DataType::Float32>(
542 workloadFactory,
543 memoryManager,
544 s_SimpleTestData,
545 armnn::ComparisonOperation::GreaterOrEqual);
546}
547
548LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1ElementTest(
549 armnn::IWorkloadFactory& workloadFactory,
550 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
551{
552 return ComparisonTestImpl<armnn::DataType::Float32>(
553 workloadFactory,
554 memoryManager,
555 s_Broadcast1ElementTestData,
556 armnn::ComparisonOperation::GreaterOrEqual);
557}
558
559LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1dVectorTest(
560 armnn::IWorkloadFactory& workloadFactory,
561 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
562{
563 return ComparisonTestImpl<armnn::DataType::Float32>(
564 workloadFactory,
565 memoryManager,
566 s_Broadcast1dVectorTestData,
567 armnn::ComparisonOperation::GreaterOrEqual);
568}
569
570LayerTestResult<uint8_t, 4> GreaterOrEqualSimpleFloat16Test(
571 armnn::IWorkloadFactory& workloadFactory,
572 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
573{
574 return ComparisonTestImpl<armnn::DataType::Float16>(
575 workloadFactory,
576 memoryManager,
577 s_SimpleTestData,
578 armnn::ComparisonOperation::GreaterOrEqual);
579}
580
581LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1ElementFloat16Test(
582 armnn::IWorkloadFactory& workloadFactory,
583 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
584{
585 return ComparisonTestImpl<armnn::DataType::Float16>(
586 workloadFactory,
587 memoryManager,
588 s_Broadcast1ElementTestData,
589 armnn::ComparisonOperation::GreaterOrEqual);
590}
591
592LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1dVectorFloat16Test(
593 armnn::IWorkloadFactory& workloadFactory,
594 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
595{
596 return ComparisonTestImpl<armnn::DataType::Float16>(
597 workloadFactory,
598 memoryManager,
599 s_Broadcast1dVectorTestData,
600 armnn::ComparisonOperation::GreaterOrEqual);
601}
602
603LayerTestResult<uint8_t, 4> GreaterOrEqualSimpleUint8Test(
604 armnn::IWorkloadFactory& workloadFactory,
605 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
606{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000607 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100608 workloadFactory,
609 memoryManager,
610 s_SimpleTestData,
611 armnn::ComparisonOperation::GreaterOrEqual);
612}
613
614LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1ElementUint8Test(
615 armnn::IWorkloadFactory& workloadFactory,
616 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
617{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000618 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100619 workloadFactory,
620 memoryManager,
621 s_Broadcast1ElementTestData,
622 armnn::ComparisonOperation::GreaterOrEqual);
623}
624
625LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1dVectorUint8Test(
626 armnn::IWorkloadFactory& workloadFactory,
627 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
628{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000629 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100630 workloadFactory,
631 memoryManager,
632 s_Broadcast1dVectorTestData,
633 armnn::ComparisonOperation::GreaterOrEqual);
634}
635
636// Less
637LayerTestResult<uint8_t, 4> LessSimpleTest(armnn::IWorkloadFactory& workloadFactory,
638 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
639{
640 return ComparisonTestImpl<armnn::DataType::Float32>(
641 workloadFactory,
642 memoryManager,
643 s_SimpleTestData,
644 armnn::ComparisonOperation::Less);
645}
646
647LayerTestResult<uint8_t, 4> LessBroadcast1ElementTest(
648 armnn::IWorkloadFactory& workloadFactory,
649 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
650{
651 return ComparisonTestImpl<armnn::DataType::Float32>(
652 workloadFactory,
653 memoryManager,
654 s_Broadcast1ElementTestData,
655 armnn::ComparisonOperation::Less);
656}
657
658LayerTestResult<uint8_t, 4> LessBroadcast1dVectorTest(
659 armnn::IWorkloadFactory& workloadFactory,
660 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
661{
662 return ComparisonTestImpl<armnn::DataType::Float32>(
663 workloadFactory,
664 memoryManager,
665 s_Broadcast1dVectorTestData,
666 armnn::ComparisonOperation::Less);
667}
668
669LayerTestResult<uint8_t, 4> LessSimpleFloat16Test(
670 armnn::IWorkloadFactory& workloadFactory,
671 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
672{
673 return ComparisonTestImpl<armnn::DataType::Float16>(
674 workloadFactory,
675 memoryManager,
676 s_SimpleTestData,
677 armnn::ComparisonOperation::Less);
678}
679
680LayerTestResult<uint8_t, 4> LessBroadcast1ElementFloat16Test(
681 armnn::IWorkloadFactory& workloadFactory,
682 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
683{
684 return ComparisonTestImpl<armnn::DataType::Float16>(
685 workloadFactory,
686 memoryManager,
687 s_Broadcast1ElementTestData,
688 armnn::ComparisonOperation::Less);
689}
690
691LayerTestResult<uint8_t, 4> LessBroadcast1dVectorFloat16Test(
692 armnn::IWorkloadFactory& workloadFactory,
693 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
694{
695 return ComparisonTestImpl<armnn::DataType::Float16>(
696 workloadFactory,
697 memoryManager,
698 s_Broadcast1dVectorTestData,
699 armnn::ComparisonOperation::Less);
700}
701
702LayerTestResult<uint8_t, 4> LessSimpleUint8Test(
703 armnn::IWorkloadFactory& workloadFactory,
704 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
705{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000706 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100707 workloadFactory,
708 memoryManager,
709 s_SimpleTestData,
710 armnn::ComparisonOperation::Less);
711}
712
713LayerTestResult<uint8_t, 4> LessBroadcast1ElementUint8Test(
714 armnn::IWorkloadFactory& workloadFactory,
715 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
716{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000717 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100718 workloadFactory,
719 memoryManager,
720 s_Broadcast1ElementTestData,
721 armnn::ComparisonOperation::Less);
722}
723
724LayerTestResult<uint8_t, 4> LessBroadcast1dVectorUint8Test(
725 armnn::IWorkloadFactory& workloadFactory,
726 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
727{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000728 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100729 workloadFactory,
730 memoryManager,
731 s_Broadcast1dVectorTestData,
732 armnn::ComparisonOperation::Less);
733}
734
735// LessOrEqual
736LayerTestResult<uint8_t, 4> LessOrEqualSimpleTest(
737 armnn::IWorkloadFactory& workloadFactory,
738 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
739{
740 return ComparisonTestImpl<armnn::DataType::Float32>(
741 workloadFactory,
742 memoryManager,
743 s_SimpleTestData,
744 armnn::ComparisonOperation::LessOrEqual);
745}
746
747LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1ElementTest(
748 armnn::IWorkloadFactory& workloadFactory,
749 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
750{
751 return ComparisonTestImpl<armnn::DataType::Float32>(
752 workloadFactory,
753 memoryManager,
754 s_Broadcast1ElementTestData,
755 armnn::ComparisonOperation::LessOrEqual);
756}
757
758LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1dVectorTest(
759 armnn::IWorkloadFactory& workloadFactory,
760 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
761{
762 return ComparisonTestImpl<armnn::DataType::Float32>(
763 workloadFactory,
764 memoryManager,
765 s_Broadcast1dVectorTestData,
766 armnn::ComparisonOperation::LessOrEqual);
767}
768
769LayerTestResult<uint8_t, 4> LessOrEqualSimpleFloat16Test(
770 armnn::IWorkloadFactory& workloadFactory,
771 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
772{
773 return ComparisonTestImpl<armnn::DataType::Float16>(
774 workloadFactory,
775 memoryManager,
776 s_SimpleTestData,
777 armnn::ComparisonOperation::LessOrEqual);
778}
779
780LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1ElementFloat16Test(
781 armnn::IWorkloadFactory& workloadFactory,
782 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
783{
784 return ComparisonTestImpl<armnn::DataType::Float16>(
785 workloadFactory,
786 memoryManager,
787 s_Broadcast1ElementTestData,
788 armnn::ComparisonOperation::LessOrEqual);
789}
790
791LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1dVectorFloat16Test(
792 armnn::IWorkloadFactory& workloadFactory,
793 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
794{
795 return ComparisonTestImpl<armnn::DataType::Float16>(
796 workloadFactory,
797 memoryManager,
798 s_Broadcast1dVectorTestData,
799 armnn::ComparisonOperation::LessOrEqual);
800}
801
802LayerTestResult<uint8_t, 4> LessOrEqualSimpleUint8Test(
803 armnn::IWorkloadFactory& workloadFactory,
804 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
805{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000806 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100807 workloadFactory,
808 memoryManager,
809 s_SimpleTestData,
810 armnn::ComparisonOperation::LessOrEqual);
811}
812
813LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1ElementUint8Test(
814 armnn::IWorkloadFactory& workloadFactory,
815 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
816{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000817 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100818 workloadFactory,
819 memoryManager,
820 s_Broadcast1ElementTestData,
821 armnn::ComparisonOperation::LessOrEqual);
822}
823
824LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1dVectorUint8Test(
825 armnn::IWorkloadFactory& workloadFactory,
826 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
827{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000828 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100829 workloadFactory,
830 memoryManager,
831 s_Broadcast1dVectorTestData,
832 armnn::ComparisonOperation::LessOrEqual);
833}
834
835// NotEqual
836LayerTestResult<uint8_t, 4> NotEqualSimpleTest(
837 armnn::IWorkloadFactory& workloadFactory,
838 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
839{
840 return ComparisonTestImpl<armnn::DataType::Float32>(
841 workloadFactory,
842 memoryManager,
843 s_SimpleTestData,
844 armnn::ComparisonOperation::NotEqual);
845}
846
847LayerTestResult<uint8_t, 4> NotEqualBroadcast1ElementTest(
848 armnn::IWorkloadFactory& workloadFactory,
849 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
850{
851 return ComparisonTestImpl<armnn::DataType::Float32>(
852 workloadFactory,
853 memoryManager,
854 s_Broadcast1ElementTestData,
855 armnn::ComparisonOperation::NotEqual);
856}
857
858LayerTestResult<uint8_t, 4> NotEqualBroadcast1dVectorTest(
859 armnn::IWorkloadFactory& workloadFactory,
860 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
861{
862 return ComparisonTestImpl<armnn::DataType::Float32>(
863 workloadFactory,
864 memoryManager,
865 s_Broadcast1dVectorTestData,
866 armnn::ComparisonOperation::NotEqual);
867}
868
869LayerTestResult<uint8_t, 4> NotEqualSimpleFloat16Test(
870 armnn::IWorkloadFactory& workloadFactory,
871 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
872{
873 return ComparisonTestImpl<armnn::DataType::Float16>(
874 workloadFactory,
875 memoryManager,
876 s_SimpleTestData,
877 armnn::ComparisonOperation::NotEqual);
878}
879
880LayerTestResult<uint8_t, 4> NotEqualBroadcast1ElementFloat16Test(
881 armnn::IWorkloadFactory& workloadFactory,
882 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
883{
884 return ComparisonTestImpl<armnn::DataType::Float16>(
885 workloadFactory,
886 memoryManager,
887 s_Broadcast1ElementTestData,
888 armnn::ComparisonOperation::NotEqual);
889}
890
891LayerTestResult<uint8_t, 4> NotEqualBroadcast1dVectorFloat16Test(
892 armnn::IWorkloadFactory& workloadFactory,
893 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
894{
895 return ComparisonTestImpl<armnn::DataType::Float16>(
896 workloadFactory,
897 memoryManager,
898 s_Broadcast1dVectorTestData,
899 armnn::ComparisonOperation::NotEqual);
900}
901
902LayerTestResult<uint8_t, 4> NotEqualSimpleUint8Test(
903 armnn::IWorkloadFactory& workloadFactory,
904 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
905{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000906 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100907 workloadFactory,
908 memoryManager,
909 s_SimpleTestData,
910 armnn::ComparisonOperation::NotEqual);
911}
912
913LayerTestResult<uint8_t, 4> NotEqualBroadcast1ElementUint8Test(
914 armnn::IWorkloadFactory& workloadFactory,
915 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
916{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000917 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100918 workloadFactory,
919 memoryManager,
920 s_Broadcast1ElementTestData,
921 armnn::ComparisonOperation::NotEqual);
922}
923
924LayerTestResult<uint8_t, 4> NotEqualBroadcast1dVectorUint8Test(
925 armnn::IWorkloadFactory& workloadFactory,
926 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
927{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000928 return ComparisonTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100929 workloadFactory,
930 memoryManager,
931 s_Broadcast1dVectorTestData,
932 armnn::ComparisonOperation::NotEqual);
933}