blob: c8272f47f0f05aa34c83242e7fefdd76ba540e6c [file] [log] [blame]
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ComparisonTestImpl.hpp"
7
8#include <armnn/ArmNN.hpp>
9
10#include <Half.hpp>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010011#include <QuantizeHelper.hpp>
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010012#include <ResolveType.hpp>
13
14#include <backendsCommon/Workload.hpp>
15#include <backendsCommon/WorkloadData.hpp>
16
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010017#include <backendsCommon/test/TensorCopyUtils.hpp>
18#include <backendsCommon/test/WorkloadTestUtils.hpp>
19
20#include <test/TensorHelpers.hpp>
21
22#include <boost/assert.hpp>
23
24namespace
25{
26
27template <std::size_t NumDims,
28 armnn::DataType ArmnnInType,
29 typename InType = armnn::ResolveType<ArmnnInType>>
30LayerTestResult<uint8_t, NumDims> ComparisonTestImpl(
31 armnn::IWorkloadFactory & workloadFactory,
32 const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,
33 const armnn::ComparisonDescriptor& descriptor,
34 const armnn::TensorShape& shape0,
35 std::vector<InType> values0,
36 float quantScale0,
37 int quantOffset0,
38 const armnn::TensorShape& shape1,
39 std::vector<InType> values1,
40 float quantScale1,
41 int quantOffset1,
42 const armnn::TensorShape& outShape,
43 std::vector<uint8_t> outValues,
44 float outQuantScale,
45 int outQuantOffset)
46{
Derek Lambertic374ff02019-12-10 21:57:35 +000047 boost::ignore_unused(memoryManager);
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +010048 BOOST_ASSERT(shape0.GetNumDimensions() == NumDims);
49 armnn::TensorInfo inputTensorInfo0(shape0, ArmnnInType, quantScale0, quantOffset0);
50
51 BOOST_ASSERT(shape1.GetNumDimensions() == NumDims);
52 armnn::TensorInfo inputTensorInfo1(shape1, ArmnnInType, quantScale1, quantOffset1);
53
54 BOOST_ASSERT(outShape.GetNumDimensions() == NumDims);
55 armnn::TensorInfo outputTensorInfo(outShape, armnn::DataType::Boolean, outQuantScale, outQuantOffset);
56
57 auto input0 = MakeTensor<InType, NumDims>(inputTensorInfo0, values0);
58 auto input1 = MakeTensor<InType, NumDims>(inputTensorInfo1, values1);
59
60 LayerTestResult<uint8_t, NumDims> ret(outputTensorInfo);
61
62 std::unique_ptr<armnn::ITensorHandle> inputHandle0 = workloadFactory.CreateTensorHandle(inputTensorInfo0);
63 std::unique_ptr<armnn::ITensorHandle> inputHandle1 = workloadFactory.CreateTensorHandle(inputTensorInfo1);
64 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
65
66 armnn::ComparisonQueueDescriptor qDescriptor;
67 qDescriptor.m_Parameters = descriptor;
68
69 armnn::WorkloadInfo info;
70 AddInputToWorkload(qDescriptor, info, inputTensorInfo0, inputHandle0.get());
71 AddInputToWorkload(qDescriptor, info, inputTensorInfo1, inputHandle1.get());
72 AddOutputToWorkload(qDescriptor, info, outputTensorInfo, outputHandle.get());
73
74 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateComparison(qDescriptor, info);
75
76 inputHandle0->Allocate();
77 inputHandle1->Allocate();
78 outputHandle->Allocate();
79
80 CopyDataToITensorHandle(inputHandle0.get(), input0.origin());
81 CopyDataToITensorHandle(inputHandle1.get(), input1.origin());
82
83 workload->PostAllocationConfigure();
84 ExecuteWorkload(*workload, memoryManager);
85
86 CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
87
88 ret.outputExpected = MakeTensor<uint8_t, NumDims>(outputTensorInfo, outValues);
89 ret.compareBoolean = true;
90
91 return ret;
92}
93
94template <std::size_t NumDims,
95 armnn::DataType ArmnnInType,
96 typename InType = armnn::ResolveType<ArmnnInType>>
97LayerTestResult<uint8_t, NumDims> ComparisonTestImpl(
98 armnn::IWorkloadFactory & workloadFactory,
99 const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,
100 const armnn::ComparisonDescriptor& descriptor,
101 const armnn::TensorShape& shape0,
102 std::vector<InType> values0,
103 const armnn::TensorShape& shape1,
104 std::vector<InType> values1,
105 const armnn::TensorShape outShape,
106 std::vector<uint8_t> outValues,
107 float quantScale = 1.f,
108 int quantOffset = 0)
109{
110 return ComparisonTestImpl<NumDims, ArmnnInType>(
111 workloadFactory,
112 memoryManager,
113 descriptor,
114 shape0,
115 values0,
116 quantScale,
117 quantOffset,
118 shape1,
119 values1,
120 quantScale,
121 quantOffset,
122 outShape,
123 outValues,
124 quantScale,
125 quantOffset);
126}
127
128template<typename TestData>
129std::vector<uint8_t> GetExpectedOutputData(const TestData& testData, armnn::ComparisonOperation operation)
130{
131 switch (operation)
132 {
133 case armnn::ComparisonOperation::Equal:
134 return testData.m_OutputEqual;
135 case armnn::ComparisonOperation::Greater:
136 return testData.m_OutputGreater;
137 case armnn::ComparisonOperation::GreaterOrEqual:
138 return testData.m_OutputGreaterOrEqual;
139 case armnn::ComparisonOperation::Less:
140 return testData.m_OutputLess;
141 case armnn::ComparisonOperation::LessOrEqual:
142 return testData.m_OutputLessOrEqual;
143 case armnn::ComparisonOperation::NotEqual:
144 default:
145 return testData.m_OutputNotEqual;
146 }
147}
148
149template<armnn::DataType ArmnnInType, typename TestData>
150LayerTestResult<uint8_t, 4> ComparisonTestImpl(armnn::IWorkloadFactory& workloadFactory,
151 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
152 const TestData& testData,
153 armnn::ComparisonOperation operation,
154 float quantScale = 1.f,
155 int quantOffset = 0)
156{
157 using T = armnn::ResolveType<ArmnnInType>;
158
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100159 std::vector<T> inputData0 = armnnUtils::QuantizedVector<T>(testData.m_InputData0, quantScale, quantOffset);
160 std::vector<T> inputData1 = armnnUtils::QuantizedVector<T>(testData.m_InputData1, quantScale, quantOffset);
Aron Virginas-Tar3bc00ec2019-10-18 15:42:58 +0100161
162 return ComparisonTestImpl<4, ArmnnInType>(
163 workloadFactory,
164 memoryManager,
165 armnn::ComparisonDescriptor(operation),
166 testData.m_InputShape0,
167 inputData0,
168 testData.m_InputShape1,
169 inputData1,
170 testData.m_OutputShape,
171 GetExpectedOutputData(testData, operation),
172 quantScale,
173 quantOffset);
174}
175
176class ComparisonTestData
177{
178public:
179 ComparisonTestData() = default;
180 virtual ~ComparisonTestData() = default;
181
182 armnn::TensorShape m_InputShape0;
183 armnn::TensorShape m_InputShape1;
184 armnn::TensorShape m_OutputShape;
185
186 std::vector<float> m_InputData0;
187 std::vector<float> m_InputData1;
188
189 std::vector<uint8_t> m_OutputEqual;
190 std::vector<uint8_t> m_OutputGreater;
191 std::vector<uint8_t> m_OutputGreaterOrEqual;
192 std::vector<uint8_t> m_OutputLess;
193 std::vector<uint8_t> m_OutputLessOrEqual;
194 std::vector<uint8_t> m_OutputNotEqual;
195};
196
197class SimpleTestData : public ComparisonTestData
198{
199public:
200 SimpleTestData() : ComparisonTestData()
201 {
202 m_InputShape0 = { 2, 2, 2, 2 };
203
204 m_InputShape1 = m_InputShape0;
205 m_OutputShape = m_InputShape0;
206
207 m_InputData0 =
208 {
209 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
210 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
211 };
212
213 m_InputData1 =
214 {
215 1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
216 5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
217 };
218
219 m_OutputEqual =
220 {
221 1, 1, 1, 1, 0, 0, 0, 0,
222 0, 0, 0, 0, 1, 1, 1, 1
223 };
224
225 m_OutputGreater =
226 {
227 0, 0, 0, 0, 1, 1, 1, 1,
228 0, 0, 0, 0, 0, 0, 0, 0
229 };
230
231 m_OutputGreaterOrEqual =
232 {
233 1, 1, 1, 1, 1, 1, 1, 1,
234 0, 0, 0, 0, 1, 1, 1, 1
235 };
236
237 m_OutputLess =
238 {
239 0, 0, 0, 0, 0, 0, 0, 0,
240 1, 1, 1, 1, 0, 0, 0, 0
241 };
242
243 m_OutputLessOrEqual =
244 {
245 1, 1, 1, 1, 0, 0, 0, 0,
246 1, 1, 1, 1, 1, 1, 1, 1
247 };
248
249 m_OutputNotEqual =
250 {
251 0, 0, 0, 0, 1, 1, 1, 1,
252 1, 1, 1, 1, 0, 0, 0, 0
253 };
254 }
255};
256
257class Broadcast1ElementTestData : public ComparisonTestData
258{
259public:
260 Broadcast1ElementTestData() : ComparisonTestData()
261 {
262 m_InputShape0 = { 1, 2, 2, 2 };
263 m_InputShape1 = { 1, 1, 1, 1 };
264
265 m_OutputShape = m_InputShape0;
266
267 m_InputData0 = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f };
268 m_InputData1 = { 3.f };
269
270 m_OutputEqual = { 0, 0, 1, 0, 0, 0, 0, 0 };
271 m_OutputGreater = { 0, 0, 0, 1, 1, 1, 1, 1 };
272 m_OutputGreaterOrEqual = { 0, 0, 1, 1, 1, 1, 1, 1 };
273 m_OutputLess = { 1, 1, 0, 0, 0, 0, 0, 0 };
274 m_OutputLessOrEqual = { 1, 1, 1, 0, 0, 0, 0, 0 };
275 m_OutputNotEqual = { 1, 1, 0, 1, 1, 1, 1, 1 };
276 }
277};
278
279class Broadcast1dVectorTestData : public ComparisonTestData
280{
281public:
282 Broadcast1dVectorTestData() : ComparisonTestData()
283 {
284 m_InputShape0 = { 1, 2, 2, 3 };
285 m_InputShape1 = { 1, 1, 1, 3 };
286
287 m_OutputShape = m_InputShape0;
288
289 m_InputData0 =
290 {
291 1.f, 2.f, 3.f, 4.f, 5.f, 6.f,
292 7.f, 8.f, 9.f, 10.f, 11.f, 12.f
293 };
294
295 m_InputData1 = { 4.f, 5.f, 6.f };
296
297 m_OutputEqual =
298 {
299 0, 0, 0, 1, 1, 1,
300 0, 0, 0, 0, 0, 0
301 };
302
303 m_OutputGreater =
304 {
305 0, 0, 0, 0, 0, 0,
306 1, 1, 1, 1, 1, 1
307 };
308
309 m_OutputGreaterOrEqual =
310 {
311 0, 0, 0, 1, 1, 1,
312 1, 1, 1, 1, 1, 1
313 };
314
315 m_OutputLess =
316 {
317 1, 1, 1, 0, 0, 0,
318 0, 0, 0, 0, 0, 0
319 };
320
321 m_OutputLessOrEqual =
322 {
323 1, 1, 1, 1, 1, 1,
324 0, 0, 0, 0, 0, 0
325 };
326
327 m_OutputNotEqual =
328 {
329 1, 1, 1, 0, 0, 0,
330 1, 1, 1, 1, 1, 1
331 };
332 }
333};
334
335static SimpleTestData s_SimpleTestData;
336static Broadcast1ElementTestData s_Broadcast1ElementTestData;
337static Broadcast1dVectorTestData s_Broadcast1dVectorTestData;
338
339} // anonymous namespace
340
341// Equal
342LayerTestResult<uint8_t, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFactory,
343 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
344{
345 return ComparisonTestImpl<armnn::DataType::Float32>(
346 workloadFactory,
347 memoryManager,
348 s_SimpleTestData,
349 armnn::ComparisonOperation::Equal);
350}
351
352LayerTestResult<uint8_t, 4> EqualBroadcast1ElementTest(
353 armnn::IWorkloadFactory& workloadFactory,
354 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
355{
356 return ComparisonTestImpl<armnn::DataType::Float32>(
357 workloadFactory,
358 memoryManager,
359 s_Broadcast1ElementTestData,
360 armnn::ComparisonOperation::Equal);
361}
362
363LayerTestResult<uint8_t, 4> EqualBroadcast1dVectorTest(
364 armnn::IWorkloadFactory& workloadFactory,
365 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
366{
367 return ComparisonTestImpl<armnn::DataType::Float32>(
368 workloadFactory,
369 memoryManager,
370 s_Broadcast1dVectorTestData,
371 armnn::ComparisonOperation::Equal);
372}
373
374LayerTestResult<uint8_t, 4> EqualSimpleFloat16Test(
375 armnn::IWorkloadFactory& workloadFactory,
376 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
377{
378 return ComparisonTestImpl<armnn::DataType::Float16>(
379 workloadFactory,
380 memoryManager,
381 s_SimpleTestData,
382 armnn::ComparisonOperation::Equal);
383}
384
385LayerTestResult<uint8_t, 4> EqualBroadcast1ElementFloat16Test(
386 armnn::IWorkloadFactory& workloadFactory,
387 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
388{
389 return ComparisonTestImpl<armnn::DataType::Float16>(
390 workloadFactory,
391 memoryManager,
392 s_Broadcast1ElementTestData,
393 armnn::ComparisonOperation::Equal);
394}
395
396LayerTestResult<uint8_t, 4> EqualBroadcast1dVectorFloat16Test(
397 armnn::IWorkloadFactory& workloadFactory,
398 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
399{
400 return ComparisonTestImpl<armnn::DataType::Float16>(
401 workloadFactory,
402 memoryManager,
403 s_Broadcast1dVectorTestData,
404 armnn::ComparisonOperation::Equal);
405}
406
407LayerTestResult<uint8_t, 4> EqualSimpleUint8Test(
408 armnn::IWorkloadFactory& workloadFactory,
409 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
410{
411 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
412 workloadFactory,
413 memoryManager,
414 s_SimpleTestData,
415 armnn::ComparisonOperation::Equal);
416}
417
418LayerTestResult<uint8_t, 4> EqualBroadcast1ElementUint8Test(
419 armnn::IWorkloadFactory& workloadFactory,
420 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
421{
422 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
423 workloadFactory,
424 memoryManager,
425 s_Broadcast1ElementTestData,
426 armnn::ComparisonOperation::Equal);
427}
428
429LayerTestResult<uint8_t, 4> EqualBroadcast1dVectorUint8Test(
430 armnn::IWorkloadFactory& workloadFactory,
431 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
432{
433 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
434 workloadFactory,
435 memoryManager,
436 s_Broadcast1dVectorTestData,
437 armnn::ComparisonOperation::Equal);
438}
439
440// Greater
441LayerTestResult<uint8_t, 4> GreaterSimpleTest(armnn::IWorkloadFactory& workloadFactory,
442 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
443{
444 return ComparisonTestImpl<armnn::DataType::Float32>(
445 workloadFactory,
446 memoryManager,
447 s_SimpleTestData,
448 armnn::ComparisonOperation::Greater);
449}
450
451LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementTest(
452 armnn::IWorkloadFactory& workloadFactory,
453 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
454{
455 return ComparisonTestImpl<armnn::DataType::Float32>(
456 workloadFactory,
457 memoryManager,
458 s_Broadcast1ElementTestData,
459 armnn::ComparisonOperation::Greater);
460}
461
462LayerTestResult<uint8_t, 4> GreaterBroadcast1dVectorTest(
463 armnn::IWorkloadFactory& workloadFactory,
464 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
465{
466 return ComparisonTestImpl<armnn::DataType::Float32>(
467 workloadFactory,
468 memoryManager,
469 s_Broadcast1dVectorTestData,
470 armnn::ComparisonOperation::Greater);
471}
472
473LayerTestResult<uint8_t, 4> GreaterSimpleFloat16Test(
474 armnn::IWorkloadFactory& workloadFactory,
475 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
476{
477 return ComparisonTestImpl<armnn::DataType::Float16>(
478 workloadFactory,
479 memoryManager,
480 s_SimpleTestData,
481 armnn::ComparisonOperation::Greater);
482}
483
484LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementFloat16Test(
485 armnn::IWorkloadFactory& workloadFactory,
486 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
487{
488 return ComparisonTestImpl<armnn::DataType::Float16>(
489 workloadFactory,
490 memoryManager,
491 s_Broadcast1ElementTestData,
492 armnn::ComparisonOperation::Greater);
493}
494
495LayerTestResult<uint8_t, 4> GreaterBroadcast1dVectorFloat16Test(
496 armnn::IWorkloadFactory& workloadFactory,
497 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
498{
499 return ComparisonTestImpl<armnn::DataType::Float16>(
500 workloadFactory,
501 memoryManager,
502 s_Broadcast1dVectorTestData,
503 armnn::ComparisonOperation::Greater);
504}
505
506LayerTestResult<uint8_t, 4> GreaterSimpleUint8Test(
507 armnn::IWorkloadFactory& workloadFactory,
508 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
509{
510 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
511 workloadFactory,
512 memoryManager,
513 s_SimpleTestData,
514 armnn::ComparisonOperation::Greater);
515}
516
517LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementUint8Test(
518 armnn::IWorkloadFactory& workloadFactory,
519 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
520{
521 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
522 workloadFactory,
523 memoryManager,
524 s_Broadcast1ElementTestData,
525 armnn::ComparisonOperation::Greater);
526}
527
528LayerTestResult<uint8_t, 4> GreaterBroadcast1dVectorUint8Test(
529 armnn::IWorkloadFactory& workloadFactory,
530 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
531{
532 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
533 workloadFactory,
534 memoryManager,
535 s_Broadcast1dVectorTestData,
536 armnn::ComparisonOperation::Greater);
537}
538
539// GreaterOrEqual
540LayerTestResult<uint8_t, 4> GreaterOrEqualSimpleTest(
541 armnn::IWorkloadFactory& workloadFactory,
542 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
543{
544 return ComparisonTestImpl<armnn::DataType::Float32>(
545 workloadFactory,
546 memoryManager,
547 s_SimpleTestData,
548 armnn::ComparisonOperation::GreaterOrEqual);
549}
550
551LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1ElementTest(
552 armnn::IWorkloadFactory& workloadFactory,
553 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
554{
555 return ComparisonTestImpl<armnn::DataType::Float32>(
556 workloadFactory,
557 memoryManager,
558 s_Broadcast1ElementTestData,
559 armnn::ComparisonOperation::GreaterOrEqual);
560}
561
562LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1dVectorTest(
563 armnn::IWorkloadFactory& workloadFactory,
564 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
565{
566 return ComparisonTestImpl<armnn::DataType::Float32>(
567 workloadFactory,
568 memoryManager,
569 s_Broadcast1dVectorTestData,
570 armnn::ComparisonOperation::GreaterOrEqual);
571}
572
573LayerTestResult<uint8_t, 4> GreaterOrEqualSimpleFloat16Test(
574 armnn::IWorkloadFactory& workloadFactory,
575 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
576{
577 return ComparisonTestImpl<armnn::DataType::Float16>(
578 workloadFactory,
579 memoryManager,
580 s_SimpleTestData,
581 armnn::ComparisonOperation::GreaterOrEqual);
582}
583
584LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1ElementFloat16Test(
585 armnn::IWorkloadFactory& workloadFactory,
586 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
587{
588 return ComparisonTestImpl<armnn::DataType::Float16>(
589 workloadFactory,
590 memoryManager,
591 s_Broadcast1ElementTestData,
592 armnn::ComparisonOperation::GreaterOrEqual);
593}
594
595LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1dVectorFloat16Test(
596 armnn::IWorkloadFactory& workloadFactory,
597 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
598{
599 return ComparisonTestImpl<armnn::DataType::Float16>(
600 workloadFactory,
601 memoryManager,
602 s_Broadcast1dVectorTestData,
603 armnn::ComparisonOperation::GreaterOrEqual);
604}
605
606LayerTestResult<uint8_t, 4> GreaterOrEqualSimpleUint8Test(
607 armnn::IWorkloadFactory& workloadFactory,
608 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
609{
610 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
611 workloadFactory,
612 memoryManager,
613 s_SimpleTestData,
614 armnn::ComparisonOperation::GreaterOrEqual);
615}
616
617LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1ElementUint8Test(
618 armnn::IWorkloadFactory& workloadFactory,
619 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
620{
621 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
622 workloadFactory,
623 memoryManager,
624 s_Broadcast1ElementTestData,
625 armnn::ComparisonOperation::GreaterOrEqual);
626}
627
628LayerTestResult<uint8_t, 4> GreaterOrEqualBroadcast1dVectorUint8Test(
629 armnn::IWorkloadFactory& workloadFactory,
630 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
631{
632 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
633 workloadFactory,
634 memoryManager,
635 s_Broadcast1dVectorTestData,
636 armnn::ComparisonOperation::GreaterOrEqual);
637}
638
639// Less
640LayerTestResult<uint8_t, 4> LessSimpleTest(armnn::IWorkloadFactory& workloadFactory,
641 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
642{
643 return ComparisonTestImpl<armnn::DataType::Float32>(
644 workloadFactory,
645 memoryManager,
646 s_SimpleTestData,
647 armnn::ComparisonOperation::Less);
648}
649
650LayerTestResult<uint8_t, 4> LessBroadcast1ElementTest(
651 armnn::IWorkloadFactory& workloadFactory,
652 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
653{
654 return ComparisonTestImpl<armnn::DataType::Float32>(
655 workloadFactory,
656 memoryManager,
657 s_Broadcast1ElementTestData,
658 armnn::ComparisonOperation::Less);
659}
660
661LayerTestResult<uint8_t, 4> LessBroadcast1dVectorTest(
662 armnn::IWorkloadFactory& workloadFactory,
663 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
664{
665 return ComparisonTestImpl<armnn::DataType::Float32>(
666 workloadFactory,
667 memoryManager,
668 s_Broadcast1dVectorTestData,
669 armnn::ComparisonOperation::Less);
670}
671
672LayerTestResult<uint8_t, 4> LessSimpleFloat16Test(
673 armnn::IWorkloadFactory& workloadFactory,
674 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
675{
676 return ComparisonTestImpl<armnn::DataType::Float16>(
677 workloadFactory,
678 memoryManager,
679 s_SimpleTestData,
680 armnn::ComparisonOperation::Less);
681}
682
683LayerTestResult<uint8_t, 4> LessBroadcast1ElementFloat16Test(
684 armnn::IWorkloadFactory& workloadFactory,
685 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
686{
687 return ComparisonTestImpl<armnn::DataType::Float16>(
688 workloadFactory,
689 memoryManager,
690 s_Broadcast1ElementTestData,
691 armnn::ComparisonOperation::Less);
692}
693
694LayerTestResult<uint8_t, 4> LessBroadcast1dVectorFloat16Test(
695 armnn::IWorkloadFactory& workloadFactory,
696 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
697{
698 return ComparisonTestImpl<armnn::DataType::Float16>(
699 workloadFactory,
700 memoryManager,
701 s_Broadcast1dVectorTestData,
702 armnn::ComparisonOperation::Less);
703}
704
705LayerTestResult<uint8_t, 4> LessSimpleUint8Test(
706 armnn::IWorkloadFactory& workloadFactory,
707 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
708{
709 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
710 workloadFactory,
711 memoryManager,
712 s_SimpleTestData,
713 armnn::ComparisonOperation::Less);
714}
715
716LayerTestResult<uint8_t, 4> LessBroadcast1ElementUint8Test(
717 armnn::IWorkloadFactory& workloadFactory,
718 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
719{
720 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
721 workloadFactory,
722 memoryManager,
723 s_Broadcast1ElementTestData,
724 armnn::ComparisonOperation::Less);
725}
726
727LayerTestResult<uint8_t, 4> LessBroadcast1dVectorUint8Test(
728 armnn::IWorkloadFactory& workloadFactory,
729 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
730{
731 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
732 workloadFactory,
733 memoryManager,
734 s_Broadcast1dVectorTestData,
735 armnn::ComparisonOperation::Less);
736}
737
738// LessOrEqual
739LayerTestResult<uint8_t, 4> LessOrEqualSimpleTest(
740 armnn::IWorkloadFactory& workloadFactory,
741 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
742{
743 return ComparisonTestImpl<armnn::DataType::Float32>(
744 workloadFactory,
745 memoryManager,
746 s_SimpleTestData,
747 armnn::ComparisonOperation::LessOrEqual);
748}
749
750LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1ElementTest(
751 armnn::IWorkloadFactory& workloadFactory,
752 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
753{
754 return ComparisonTestImpl<armnn::DataType::Float32>(
755 workloadFactory,
756 memoryManager,
757 s_Broadcast1ElementTestData,
758 armnn::ComparisonOperation::LessOrEqual);
759}
760
761LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1dVectorTest(
762 armnn::IWorkloadFactory& workloadFactory,
763 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
764{
765 return ComparisonTestImpl<armnn::DataType::Float32>(
766 workloadFactory,
767 memoryManager,
768 s_Broadcast1dVectorTestData,
769 armnn::ComparisonOperation::LessOrEqual);
770}
771
772LayerTestResult<uint8_t, 4> LessOrEqualSimpleFloat16Test(
773 armnn::IWorkloadFactory& workloadFactory,
774 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
775{
776 return ComparisonTestImpl<armnn::DataType::Float16>(
777 workloadFactory,
778 memoryManager,
779 s_SimpleTestData,
780 armnn::ComparisonOperation::LessOrEqual);
781}
782
783LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1ElementFloat16Test(
784 armnn::IWorkloadFactory& workloadFactory,
785 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
786{
787 return ComparisonTestImpl<armnn::DataType::Float16>(
788 workloadFactory,
789 memoryManager,
790 s_Broadcast1ElementTestData,
791 armnn::ComparisonOperation::LessOrEqual);
792}
793
794LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1dVectorFloat16Test(
795 armnn::IWorkloadFactory& workloadFactory,
796 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
797{
798 return ComparisonTestImpl<armnn::DataType::Float16>(
799 workloadFactory,
800 memoryManager,
801 s_Broadcast1dVectorTestData,
802 armnn::ComparisonOperation::LessOrEqual);
803}
804
805LayerTestResult<uint8_t, 4> LessOrEqualSimpleUint8Test(
806 armnn::IWorkloadFactory& workloadFactory,
807 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
808{
809 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
810 workloadFactory,
811 memoryManager,
812 s_SimpleTestData,
813 armnn::ComparisonOperation::LessOrEqual);
814}
815
816LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1ElementUint8Test(
817 armnn::IWorkloadFactory& workloadFactory,
818 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
819{
820 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
821 workloadFactory,
822 memoryManager,
823 s_Broadcast1ElementTestData,
824 armnn::ComparisonOperation::LessOrEqual);
825}
826
827LayerTestResult<uint8_t, 4> LessOrEqualBroadcast1dVectorUint8Test(
828 armnn::IWorkloadFactory& workloadFactory,
829 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
830{
831 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
832 workloadFactory,
833 memoryManager,
834 s_Broadcast1dVectorTestData,
835 armnn::ComparisonOperation::LessOrEqual);
836}
837
838// NotEqual
839LayerTestResult<uint8_t, 4> NotEqualSimpleTest(
840 armnn::IWorkloadFactory& workloadFactory,
841 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
842{
843 return ComparisonTestImpl<armnn::DataType::Float32>(
844 workloadFactory,
845 memoryManager,
846 s_SimpleTestData,
847 armnn::ComparisonOperation::NotEqual);
848}
849
850LayerTestResult<uint8_t, 4> NotEqualBroadcast1ElementTest(
851 armnn::IWorkloadFactory& workloadFactory,
852 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
853{
854 return ComparisonTestImpl<armnn::DataType::Float32>(
855 workloadFactory,
856 memoryManager,
857 s_Broadcast1ElementTestData,
858 armnn::ComparisonOperation::NotEqual);
859}
860
861LayerTestResult<uint8_t, 4> NotEqualBroadcast1dVectorTest(
862 armnn::IWorkloadFactory& workloadFactory,
863 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
864{
865 return ComparisonTestImpl<armnn::DataType::Float32>(
866 workloadFactory,
867 memoryManager,
868 s_Broadcast1dVectorTestData,
869 armnn::ComparisonOperation::NotEqual);
870}
871
872LayerTestResult<uint8_t, 4> NotEqualSimpleFloat16Test(
873 armnn::IWorkloadFactory& workloadFactory,
874 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
875{
876 return ComparisonTestImpl<armnn::DataType::Float16>(
877 workloadFactory,
878 memoryManager,
879 s_SimpleTestData,
880 armnn::ComparisonOperation::NotEqual);
881}
882
883LayerTestResult<uint8_t, 4> NotEqualBroadcast1ElementFloat16Test(
884 armnn::IWorkloadFactory& workloadFactory,
885 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
886{
887 return ComparisonTestImpl<armnn::DataType::Float16>(
888 workloadFactory,
889 memoryManager,
890 s_Broadcast1ElementTestData,
891 armnn::ComparisonOperation::NotEqual);
892}
893
894LayerTestResult<uint8_t, 4> NotEqualBroadcast1dVectorFloat16Test(
895 armnn::IWorkloadFactory& workloadFactory,
896 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
897{
898 return ComparisonTestImpl<armnn::DataType::Float16>(
899 workloadFactory,
900 memoryManager,
901 s_Broadcast1dVectorTestData,
902 armnn::ComparisonOperation::NotEqual);
903}
904
905LayerTestResult<uint8_t, 4> NotEqualSimpleUint8Test(
906 armnn::IWorkloadFactory& workloadFactory,
907 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
908{
909 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
910 workloadFactory,
911 memoryManager,
912 s_SimpleTestData,
913 armnn::ComparisonOperation::NotEqual);
914}
915
916LayerTestResult<uint8_t, 4> NotEqualBroadcast1ElementUint8Test(
917 armnn::IWorkloadFactory& workloadFactory,
918 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
919{
920 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
921 workloadFactory,
922 memoryManager,
923 s_Broadcast1ElementTestData,
924 armnn::ComparisonOperation::NotEqual);
925}
926
927LayerTestResult<uint8_t, 4> NotEqualBroadcast1dVectorUint8Test(
928 armnn::IWorkloadFactory& workloadFactory,
929 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
930{
931 return ComparisonTestImpl<armnn::DataType::QuantisedAsymm8>(
932 workloadFactory,
933 memoryManager,
934 s_Broadcast1dVectorTestData,
935 armnn::ComparisonOperation::NotEqual);
936}