blob: 1a30e7c9fb6f7811973d0f29d2bada4bb5b4b900 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefElementwiseWorkload.hpp"

#include "ElementwiseFunction.hpp"
#include "Profiling.hpp"
#include "RefWorkloadUtils.hpp"
#include "StringMapping.hpp"
#include "TypeUtils.hpp"

#include <cstdint>
#include <vector>
13
14namespace armnn
15{
16
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010017template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
18void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000019{
Sadik Armaganef38d5d2019-03-25 09:03:35 +000020 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
Sadik Armaganef38d5d2019-03-25 09:03:35 +000021 const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
22 const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
23 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000024
Sadik Armaganef38d5d2019-03-25 09:03:35 +000025 const TensorShape& inShape0 = inputInfo0.GetShape();
26 const TensorShape& inShape1 = inputInfo1.GetShape();
27 const TensorShape& outShape = outputInfo.GetShape();
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000028
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010029 switch(inputInfo0.GetDataType())
Sadik Armaganef38d5d2019-03-25 09:03:35 +000030 {
31 case armnn::DataType::QuantisedAsymm8:
32 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010033 QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
34 inputInfo0.GetQuantizationScale(),
35 inputInfo0.GetQuantizationOffset());
36
37 QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
38 inputInfo1.GetQuantizationScale(),
39 inputInfo1.GetQuantizationOffset());
40
41 QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data),
42 outputInfo.GetQuantizationScale(),
43 outputInfo.GetQuantizationOffset());
44
45 ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
46 inShape1,
47 outShape,
48 decodeIterator0,
49 decodeIterator1,
50 encodeIterator0);
Sadik Armaganef38d5d2019-03-25 09:03:35 +000051 break;
52 }
53 case armnn::DataType::Float32:
54 {
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010055 FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
56 FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
57 FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data));
58
59 ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
60 inShape1,
61 outShape,
62 decodeIterator0,
63 decodeIterator1,
64 encodeIterator0);
Sadik Armaganef38d5d2019-03-25 09:03:35 +000065 break;
66 }
Sadik Armagan2999a022019-04-09 14:20:12 +010067 case armnn::DataType::QuantisedSymm16:
68 {
69 QSymm16Decoder decodeIterator0(GetInputTensorData<int16_t>(0, m_Data),
70 inputInfo0.GetQuantizationScale(),
71 inputInfo0.GetQuantizationOffset());
72
73 QSymm16Decoder decodeIterator1(GetInputTensorData<int16_t>(1, m_Data),
74 inputInfo1.GetQuantizationScale(),
75 inputInfo1.GetQuantizationOffset());
76
77 QSymm16Encoder encodeIterator0(GetOutputTensorData<int16_t>(0, m_Data),
78 outputInfo.GetQuantizationScale(),
79 outputInfo.GetQuantizationOffset());
80
81 ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
82 inShape1,
83 outShape,
84 decodeIterator0,
85 decodeIterator1,
86 encodeIterator0);
87 break;
88 }
Sadik Armaganef38d5d2019-03-25 09:03:35 +000089 default:
Sadik Armagan2e6dc3a2019-04-03 17:48:18 +010090 BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!");
Sadik Armaganef38d5d2019-03-25 09:03:35 +000091 break;
92 }
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000093}
94
95}
96
// Explicit instantiations of RefElementwiseWorkload for each supported binary
// arithmetic operation. Each instantiation pairs the elementwise functor with
// its queue descriptor and the profiling string used by Execute()'s scoped
// profiling event.

template class armnn::RefElementwiseWorkload<std::plus<float>,
                                             armnn::AdditionQueueDescriptor,
                                             armnn::StringMapping::RefAdditionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::minus<float>,
                                             armnn::SubtractionQueueDescriptor,
                                             armnn::StringMapping::RefSubtractionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::multiplies<float>,
                                             armnn::MultiplicationQueueDescriptor,
                                             armnn::StringMapping::RefMultiplicationWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::divides<float>,
                                             armnn::DivisionQueueDescriptor,
                                             armnn::StringMapping::RefDivisionWorkload_Execute>;

// armnn::maximum / armnn::minimum are project-defined functors (no std
// equivalent taking two arguments existed in the project's C++14 baseline).
template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
                                             armnn::MaximumQueueDescriptor,
                                             armnn::StringMapping::RefMaximumWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
                                             armnn::MinimumQueueDescriptor,
                                             armnn::StringMapping::RefMinimumWorkload_Execute>;