//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefElementwiseWorkload.hpp"
#include "ElementwiseFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
#include <vector>

namespace armnn
{

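// Float32 path: the Functor is applied directly to the raw float inputs and
// the result is written straight into the output tensor.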
template <typename ParentDescriptor, typename Functor>
void BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);

    auto data = Float32Workload<ParentDescriptor>::GetData();
    const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
    const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
    const TensorShape& outShape = GetTensorInfo(data.m_Outputs[0]).GetShape();

    const float* inData0 = GetInputTensorDataFloat(0, data);
    const float* inData1 = GetInputTensorDataFloat(1, data);
    float* outData = GetOutputTensorDataFloat(0, data);

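    // ElementwiseFunction iterates over the output shape, broadcasting the
    // inputs where their shapes differ, and applies Functor per element pair.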
    ElementwiseFunction<Functor>(inShape0, inShape1, outShape, inData0, inData1, outData);
}

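// Uint8 path: the quantized 8-bit inputs are dequantized to float, the same
// float Functor is applied, and the results are requantized into the output.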
template <typename ParentDescriptor, typename Functor>
void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);

    auto data = Uint8Workload<ParentDescriptor>::GetData();
    const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
    const TensorInfo& inputInfo1 = GetTensorInfo(data.m_Inputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);

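    // Dequantize each 8-bit input into a float buffer using that tensor's
    // quantization scale and offset.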
    auto dequant0 = Dequantize(GetInputTensorDataU8(0, data), inputInfo0);
    auto dequant1 = Dequantize(GetInputTensorDataU8(1, data), inputInfo1);

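    // Intermediate float buffer for the elementwise results before requantization.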
    std::vector<float> results(outputInfo.GetNumElements());

    ElementwiseFunction<Functor>(inputInfo0.GetShape(),
                                 inputInfo1.GetShape(),
                                 outputInfo.GetShape(),
                                 dequant0.data(),
                                 dequant1.data(),
                                 results.data());

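    // Requantize the float results into the 8-bit output using the output
    // tensor's quantization parameters.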
    Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);
}

} // namespace armnn

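// Explicit instantiations for the elementwise operations supported by the
// reference backend.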
template class armnn::BaseFloat32ElementwiseWorkload<armnn::AdditionQueueDescriptor, std::plus<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::AdditionQueueDescriptor, std::plus<float>>;

template class armnn::BaseFloat32ElementwiseWorkload<armnn::SubtractionQueueDescriptor, std::minus<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::SubtractionQueueDescriptor, std::minus<float>>;

template class armnn::BaseFloat32ElementwiseWorkload<armnn::MultiplicationQueueDescriptor, std::multiplies<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::MultiplicationQueueDescriptor, std::multiplies<float>>;

template class armnn::BaseFloat32ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;

template class armnn::BaseFloat32ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>;