//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefComparisonWorkload.hpp"
#include "ElementwiseFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
#include <vector>

namespace armnn {

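// Reference (CpuRef) implementation of the element-wise comparison workloads.
// Execute() decodes the two input tensors, applies the comparison Functor to each
// pair of elements and encodes the boolean result into the output tensor.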
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
void RefComparisonWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);

    const TensorShape& inShape0 = inputInfo0.GetShape();
    const TensorShape& inShape1 = inputInfo1.GetShape();
    const TensorShape& outShape = outputInfo.GetShape();

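    // Dispatch on the data type of the first input: each case sets up decoders for the
    // two inputs and a boolean encoder for the output before running the comparison.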
    switch(inputInfo0.GetDataType())
    {
        case armnn::DataType::QuantisedAsymm8:
        {
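            // Dequantize both QAsymm8 inputs using their quantization scale and offset.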
            QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
                                           inputInfo0.GetQuantizationScale(),
                                           inputInfo0.GetQuantizationOffset());

            QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
                                           inputInfo1.GetQuantizationScale(),
                                           inputInfo1.GetQuantizationOffset());

            BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));

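            // ElementwiseFunction applies the Functor to the decoded elements and writes each
            // result through the encoder; the input shapes are passed separately, so inputs of
            // different (broadcastable) shapes can be handled by the helper.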
            ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
                                                                     inShape1,
                                                                     outShape,
                                                                     decodeIterator0,
                                                                     decodeIterator1,
                                                                     encodeIterator0);
            break;
        }
        case armnn::DataType::Float32:
        {
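            // Float32 inputs are read directly; the boolean result is still written through
            // the U8 output buffer.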
            FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
            FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
            BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));

            ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
                                                                     inShape1,
                                                                     outShape,
                                                                     decodeIterator0,
                                                                     decodeIterator1,
                                                                     encodeIterator0);
            break;
        }
        default:
            BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!");
            break;
    }
}

}

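// Explicit instantiations of the comparison workloads used by the reference backend:
// Equal uses std::equal_to<float> and Greater uses std::greater<float>.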
template class armnn::RefComparisonWorkload<std::equal_to<float>,
                                            armnn::EqualQueueDescriptor,
                                            armnn::StringMapping::RefEqualWorkload_Execute>;

template class armnn::RefComparisonWorkload<std::greater<float>,
                                            armnn::GreaterQueueDescriptor,
                                            armnn::StringMapping::RefGreaterWorkload_Execute>;