blob: 6e6e1d5f21677d924601727a1795f22fc4badc6a [file] [log] [blame]
Éanna Ó Catháind57415d2018-11-28 16:24:38 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "RefElementwiseWorkload.hpp"
#include "ElementwiseFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
#include "StringMapping.hpp"
#include "TypeUtils.hpp"

#include <armnn/Exceptions.hpp>

#include <vector>
13
14namespace armnn
15{
16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010017template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
18void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000019{
Sadik Armaganef38d5d2019-03-25 09:03:35 +000020 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
Sadik Armaganef38d5d2019-03-25 09:03:35 +000021 const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
22 const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
23 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000024
Sadik Armaganef38d5d2019-03-25 09:03:35 +000025 const TensorShape& inShape0 = inputInfo0.GetShape();
26 const TensorShape& inShape1 = inputInfo1.GetShape();
27 const TensorShape& outShape = outputInfo.GetShape();
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000028
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010029 switch(inputInfo0.GetDataType())
Sadik Armaganef38d5d2019-03-25 09:03:35 +000030 {
31 case armnn::DataType::QuantisedAsymm8:
32 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010033 QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
34 inputInfo0.GetQuantizationScale(),
35 inputInfo0.GetQuantizationOffset());
36
37 QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
38 inputInfo1.GetQuantizationScale(),
39 inputInfo1.GetQuantizationOffset());
40
41 QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data),
42 outputInfo.GetQuantizationScale(),
43 outputInfo.GetQuantizationOffset());
44
45 ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
46 inShape1,
47 outShape,
48 decodeIterator0,
49 decodeIterator1,
50 encodeIterator0);
Sadik Armaganef38d5d2019-03-25 09:03:35 +000051 break;
52 }
53 case armnn::DataType::Float32:
54 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010055 FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
56 FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
57 FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data));
58
59 ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
60 inShape1,
61 outShape,
62 decodeIterator0,
63 decodeIterator1,
64 encodeIterator0);
Sadik Armaganef38d5d2019-03-25 09:03:35 +000065 break;
66 }
67 default:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010068 BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!");
Sadik Armaganef38d5d2019-03-25 09:03:35 +000069 break;
70 }
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000071}
72
73}
74
Sadik Armaganef38d5d2019-03-25 09:03:35 +000075template class armnn::RefElementwiseWorkload<std::plus<float>,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010076 armnn::AdditionQueueDescriptor,
77 armnn::StringMapping::RefAdditionWorkload_Execute>;
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000078
Sadik Armaganef38d5d2019-03-25 09:03:35 +000079template class armnn::RefElementwiseWorkload<std::minus<float>,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010080 armnn::SubtractionQueueDescriptor,
81 armnn::StringMapping::RefSubtractionWorkload_Execute>;
saoste012df12b32018-11-28 16:57:20 +000082
Sadik Armaganef38d5d2019-03-25 09:03:35 +000083template class armnn::RefElementwiseWorkload<std::multiplies<float>,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010084 armnn::MultiplicationQueueDescriptor,
85 armnn::StringMapping::RefMultiplicationWorkload_Execute>;
Sadik Armaganef38d5d2019-03-25 09:03:35 +000086
87template class armnn::RefElementwiseWorkload<std::divides<float>,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010088 armnn::DivisionQueueDescriptor,
89 armnn::StringMapping::RefDivisionWorkload_Execute>;
Sadik Armaganef38d5d2019-03-25 09:03:35 +000090
91template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010092 armnn::MaximumQueueDescriptor,
93 armnn::StringMapping::RefMaximumWorkload_Execute>;
Sadik Armaganef38d5d2019-03-25 09:03:35 +000094
95template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010096 armnn::MinimumQueueDescriptor,
97 armnn::StringMapping::RefMinimumWorkload_Execute>;