blob: b442f25c2a3907560d52c92f3fd9b738ae9c4e18 [file] [log] [blame]
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefElementwiseUnaryWorkload.hpp"
7
8#include "Decoders.hpp"
9#include "ElementwiseFunction.hpp"
10#include "Encoders.hpp"
11#include "RefWorkloadUtils.hpp"
12#include "Abs.hpp"
13#include "Exp.hpp"
14#include "Rsqrt.hpp"
15#include "Sqrt.hpp"
16
17#include <Profiling.hpp>
18
19#include <armnn/TypesUtils.hpp>
20
21#include <functional>
22
23namespace armnn
24{
25
26RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
27 const WorkloadInfo& info)
28 : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
29{}
30
Finn Williamsb8181f72021-04-07 10:23:21 +010031void RefElementwiseUnaryWorkload::Execute() const
josh minor4a3c6102020-01-06 16:40:46 -060032{
Finn Williamsb8181f72021-04-07 10:23:21 +010033 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
josh minor4a3c6102020-01-06 16:40:46 -060034}
35
Finn Williamsb8181f72021-04-07 10:23:21 +010036void RefElementwiseUnaryWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
37{
38
39 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
40}
41
42void RefElementwiseUnaryWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
josh minor4a3c6102020-01-06 16:40:46 -060043{
44 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseUnaryWorkload_Execute");
45
Finn Williamsb8181f72021-04-07 10:23:21 +010046 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
47 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
josh minor4a3c6102020-01-06 16:40:46 -060048
49 const TensorShape& inShape = inputInfo.GetShape();
50 const TensorShape& outShape = outputInfo.GetShape();
51
Finn Williamsb8181f72021-04-07 10:23:21 +010052 std::unique_ptr<Decoder<InType>> input = MakeDecoder<InType>(inputInfo, inputs[0]->Map());
53 std::unique_ptr<Encoder<OutType>> output= MakeEncoder<OutType>(outputInfo, outputs[0]->Map());
josh minor4a3c6102020-01-06 16:40:46 -060054
55 using AbsFunction = ElementwiseUnaryFunction<abs<InType>>;
56 using ExpFunction = ElementwiseUnaryFunction<exp<InType>>;
57 using NegFunction = ElementwiseUnaryFunction<std::negate<InType>>;
58 using RsqrtFunction = ElementwiseUnaryFunction<rsqrt<InType>>;
59 using SqrtFunction = ElementwiseUnaryFunction<sqrt<InType>>;
60
61 switch (m_Data.m_Parameters.m_Operation)
62 {
63 case UnaryOperation::Abs:
64 {
Finn Williamsb8181f72021-04-07 10:23:21 +010065 AbsFunction(inShape, outShape, *input, *output);
josh minor4a3c6102020-01-06 16:40:46 -060066 break;
67 }
68 case UnaryOperation::Exp:
69 {
Finn Williamsb8181f72021-04-07 10:23:21 +010070 ExpFunction(inShape, outShape, *input, *output);
josh minor4a3c6102020-01-06 16:40:46 -060071 break;
72 }
73 case UnaryOperation::Neg:
74 {
Finn Williamsb8181f72021-04-07 10:23:21 +010075 NegFunction(inShape, outShape, *input, *output);
josh minor4a3c6102020-01-06 16:40:46 -060076 break;
77 }
78 case UnaryOperation::Rsqrt:
79 {
Finn Williamsb8181f72021-04-07 10:23:21 +010080 RsqrtFunction(inShape, outShape, *input, *output);
josh minor4a3c6102020-01-06 16:40:46 -060081 break;
82 }
83 case UnaryOperation::Sqrt:
84 {
Finn Williamsb8181f72021-04-07 10:23:21 +010085 SqrtFunction(inShape, outShape, *input, *output);
josh minor4a3c6102020-01-06 16:40:46 -060086 break;
87 }
88 default:
89 {
90 throw InvalidArgumentException(std::string("Unsupported unary operation ") +
91 GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
92 }
93 }
94}
95
96} // namespace armnn