blob: 4fbb0d123f6fde0914cd8c2a20b913c976d25626 [file] [log] [blame]
josh minor4a3c6102020-01-06 16:40:46 -06001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefElementwiseUnaryWorkload.hpp"
7
8#include "Decoders.hpp"
9#include "ElementwiseFunction.hpp"
10#include "Encoders.hpp"
11#include "RefWorkloadUtils.hpp"
12#include "Abs.hpp"
13#include "Exp.hpp"
14#include "Rsqrt.hpp"
15#include "Sqrt.hpp"
16
17#include <Profiling.hpp>
18
19#include <armnn/TypesUtils.hpp>
20
21#include <functional>
22
23namespace armnn
24{
25
26RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
27 const WorkloadInfo& info)
28 : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
29{}
30
31void RefElementwiseUnaryWorkload::PostAllocationConfigure()
32{
33 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
34 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
35
36 m_Input = MakeDecoder<InType>(inputInfo);
37
38 m_Output = MakeEncoder<OutType>(outputInfo);
39}
40
41void RefElementwiseUnaryWorkload::Execute() const
42{
43 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseUnaryWorkload_Execute");
44
45 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
46 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
47
48 const TensorShape& inShape = inputInfo.GetShape();
49 const TensorShape& outShape = outputInfo.GetShape();
50
51 m_Input->Reset(m_Data.m_Inputs[0]->Map());
52 m_Output->Reset(m_Data.m_Outputs[0]->Map());
53
54 using AbsFunction = ElementwiseUnaryFunction<abs<InType>>;
55 using ExpFunction = ElementwiseUnaryFunction<exp<InType>>;
56 using NegFunction = ElementwiseUnaryFunction<std::negate<InType>>;
57 using RsqrtFunction = ElementwiseUnaryFunction<rsqrt<InType>>;
58 using SqrtFunction = ElementwiseUnaryFunction<sqrt<InType>>;
59
60 switch (m_Data.m_Parameters.m_Operation)
61 {
62 case UnaryOperation::Abs:
63 {
64 AbsFunction(inShape, outShape, *m_Input, *m_Output);
65 break;
66 }
67 case UnaryOperation::Exp:
68 {
69 ExpFunction(inShape, outShape, *m_Input, *m_Output);
70 break;
71 }
72 case UnaryOperation::Neg:
73 {
74 NegFunction(inShape, outShape, *m_Input, *m_Output);
75 break;
76 }
77 case UnaryOperation::Rsqrt:
78 {
79 RsqrtFunction(inShape, outShape, *m_Input, *m_Output);
80 break;
81 }
82 case UnaryOperation::Sqrt:
83 {
84 SqrtFunction(inShape, outShape, *m_Input, *m_Output);
85 break;
86 }
87 default:
88 {
89 throw InvalidArgumentException(std::string("Unsupported unary operation ") +
90 GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
91 }
92 }
93}
94
95} // namespace armnn