blob: 0fd0dcc4206a139dc4a24c024528d310aa2ff632 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

6#include "NeonNormalizationFloat32Workload.hpp"
7#include "backends/NeonLayerSupport.hpp"
8#include "backends/ArmComputeUtils.hpp"
9
10namespace armnn
11{
12
13NeonNormalizationFloat32Workload::NeonNormalizationFloat32Workload(const NormalizationQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010014 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000015 : Float32Workload<NormalizationQueueDescriptor>(descriptor, info)
surmeh013537c2c2018-05-18 16:31:43 +010016 , m_NormalizationLayer(memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000017{
18 m_Data.ValidateInputsOutputs("NeonNormalizationFloat32Workload", 1, 1);
19 std::string reasonIfUnsupported;
20 if (!IsNeonNormalizationDescParamsSupported(&reasonIfUnsupported, m_Data.m_Parameters))
21 {
22 throw UnimplementedException(reasonIfUnsupported);
23 }
24
25 // input and output tensors have to have the same dimensionality
26 if (info.m_InputTensorInfos[0].GetShape()[1] != info.m_OutputTensorInfos[0].GetShape()[1]
27 || info.m_InputTensorInfos[0].GetShape()[0] != info.m_OutputTensorInfos[0].GetShape()[0]
28 || info.m_InputTensorInfos[0].GetShape()[3] != info.m_OutputTensorInfos[0].GetShape()[3]
29 || info.m_InputTensorInfos[0].GetShape()[2] != info.m_OutputTensorInfos[0].GetShape()[2])
30 {
31 throw InvalidArgumentException("Normalization requires input and output tensors to have equal dimensionality.");
32 }
33
34 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
35 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
36
37 const arm_compute::NormType normType =
38 ConvertNormalizationAlgorithmChannelToAclNormType(m_Data.m_Parameters.m_NormChannelType);
39 arm_compute::NormalizationLayerInfo normalizationInfo(normType,
40 m_Data.m_Parameters.m_NormSize,
41 m_Data.m_Parameters.m_Alpha,
42 m_Data.m_Parameters.m_Beta,
43 m_Data.m_Parameters.m_K,
44 false);
45
46 m_NormalizationLayer.configure(&input, &output, normalizationInfo);
47}
48
/// Runs the normalization layer configured in the constructor.
void NeonNormalizationFloat32Workload::Execute() const
{
    // Scoped profiling event: records this workload's execution time under the CpuAcc backend.
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuAcc, "NeonNormalizationFloat32Workload_Execute");
    m_NormalizationLayer.run();
}
54
55} //namespace armnn