blob: 5dc77f849673f5b51725b6a2d310767e6f01ecdd [file] [log] [blame]
Mike Kelly3ec30772023-03-08 13:47:17 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefElementwiseBinaryWorkload.hpp"
7
8#include "Decoders.hpp"
9#include "ElementwiseFunction.hpp"
10#include "Encoders.hpp"
11#include "RefWorkloadUtils.hpp"
12#include "Maximum.hpp"
13#include "Minimum.hpp"
14
15#include <Profiling.hpp>
16
17#include <armnn/TypesUtils.hpp>
18
19#include <functional>
20
21namespace armnn
22{
23
24template<typename DataType>
25void ExecuteFunction(std::vector<ITensorHandle*> inputs,
26 std::vector<ITensorHandle*> outputs,
27 BinaryOperation operation)
28{
29 const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
30 const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
31 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
32
33 const TensorShape& inShape0 = inputInfo0.GetShape();
34 const TensorShape& inShape1 = inputInfo1.GetShape();
35 const TensorShape& outShape = outputInfo.GetShape();
36
37 std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map());
38 std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map());
39 std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map());
40
41 using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>;
42 using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>;
43 using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
44 using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
45 using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
46 using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>;
47
48 switch (operation)
49 {
50 case BinaryOperation::Add:
51 {
52 AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
53 break;
54 }
55 case BinaryOperation::Div:
56 {
57 DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
58 break;
59 }
60 case BinaryOperation::Maximum:
61 {
62 MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
63 break;
64 }
65 case BinaryOperation::Minimum:
66 {
67 MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
68 break;
69 }
70 case BinaryOperation::Mul:
71 {
72 MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
73 break;
74 }
75 case BinaryOperation::Sub:
76 {
77 SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
78 break;
79 }
80 default:
81 {
82 throw InvalidArgumentException(std::string("Unsupported binary operation ") +
83 GetBinaryOperationAsCString(operation), CHECK_LOCATION());
84 }
85 }
86}
87
88RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc,
89 const WorkloadInfo& info)
90 : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info)
91{}
92
93void RefElementwiseBinaryWorkload::Execute() const
94{
95 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
96}
97
98void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData)
99{
100
101 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
102 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
103}
104
105void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
106 std::vector<ITensorHandle*> outputs) const
107{
108 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute");
109
110 if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
111 {
112 ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
113 }
114 else
115 {
116 ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
117 }
118}
119
120} // namespace armnn