//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefElementwiseBinaryWorkload.hpp"
7
8#include "Decoders.hpp"
9#include "ElementwiseFunction.hpp"
10#include "Encoders.hpp"
11#include "RefWorkloadUtils.hpp"
12#include "Maximum.hpp"
13#include "Minimum.hpp"
John Mcloughlin0ec00872023-05-15 17:03:49 +010014#include "SquaredDifference.hpp"
15#include "Power.hpp"
Mike Kelly3ec30772023-03-08 13:47:17 +000016
17#include <Profiling.hpp>
18
19#include <armnn/TypesUtils.hpp>
20
21#include <functional>
22
23namespace armnn
24{
25
26template<typename DataType>
27void ExecuteFunction(std::vector<ITensorHandle*> inputs,
28 std::vector<ITensorHandle*> outputs,
29 BinaryOperation operation)
30{
31 const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
32 const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
33 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
34
35 const TensorShape& inShape0 = inputInfo0.GetShape();
36 const TensorShape& inShape1 = inputInfo1.GetShape();
37 const TensorShape& outShape = outputInfo.GetShape();
38
39 std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map());
40 std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map());
41 std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map());
42
43 using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>;
44 using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>;
45 using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
46 using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
47 using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
48 using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>;
John Mcloughlin0ec00872023-05-15 17:03:49 +010049 using SqDiffFunction = ElementwiseBinaryFunction<armnn::squaredDifference<DataType>>;
50 using PowerFunction = ElementwiseBinaryFunction<armnn::power<DataType>>;
Mike Kelly3ec30772023-03-08 13:47:17 +000051
52 switch (operation)
53 {
54 case BinaryOperation::Add:
55 {
56 AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
57 break;
58 }
59 case BinaryOperation::Div:
60 {
61 DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
62 break;
63 }
64 case BinaryOperation::Maximum:
65 {
66 MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
67 break;
68 }
69 case BinaryOperation::Minimum:
70 {
71 MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
72 break;
73 }
74 case BinaryOperation::Mul:
75 {
76 MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
77 break;
78 }
79 case BinaryOperation::Sub:
80 {
81 SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
82 break;
83 }
John Mcloughlin0ec00872023-05-15 17:03:49 +010084 case BinaryOperation::SqDiff:
85 {
86 SqDiffFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
87 break;
88 }
89 case BinaryOperation::Power:
90 {
91 PowerFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
92 break;
93 }
Mike Kelly3ec30772023-03-08 13:47:17 +000094 default:
95 {
96 throw InvalidArgumentException(std::string("Unsupported binary operation ") +
97 GetBinaryOperationAsCString(operation), CHECK_LOCATION());
98 }
99 }
100}
101
102RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc,
103 const WorkloadInfo& info)
104 : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info)
105{}
106
107void RefElementwiseBinaryWorkload::Execute() const
108{
109 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
110}
111
112void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData)
113{
114
115 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
116 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
117}
118
119void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
120 std::vector<ITensorHandle*> outputs) const
121{
122 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute");
123
124 if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
125 {
126 ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
127 }
128 else
129 {
130 ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
131 }
132}
133
134} // namespace armnn