blob: 0deff79dac78f6a035034056ec817b7eddb2db28 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "NeonNormalizationFloatWorkload.hpp"
David Beck0dbe0ee2018-09-24 15:59:27 +01007#include <backends/neon/NeonLayerSupport.hpp>
David Beck711fa312018-09-24 10:46:38 +01008#include <backends/aclCommon/ArmComputeUtils.hpp>
9#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
narpra0133cea4d2018-09-27 16:46:14 +010011using namespace armnn::armcomputetensorutils;
12
telsoa014fcda012018-03-09 14:13:49 +000013namespace armnn
14{
15
telsoa01c577f2c2018-08-31 09:22:23 +010016arm_compute::Status NeonNormalizationWorkloadValidate(const TensorInfo& input,
17 const TensorInfo& output,
18 const NormalizationDescriptor& descriptor)
19{
narpra0133cea4d2018-09-27 16:46:14 +010020 const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
21 const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010022
narpra0133cea4d2018-09-27 16:46:14 +010023 arm_compute::NormalizationLayerInfo normalizationInfo = BuildArmComputeNormalizationLayerInfo(descriptor);
telsoa01c577f2c2018-08-31 09:22:23 +010024
25 return arm_compute::NENormalizationLayer::validate(&aclInput, &aclOutput, normalizationInfo);
26}
27
arovir019e53a352018-08-31 15:26:35 +010028NeonNormalizationFloatWorkload::NeonNormalizationFloatWorkload(const NormalizationQueueDescriptor& descriptor,
telsoa01c577f2c2018-08-31 09:22:23 +010029 const WorkloadInfo& info,
30 std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
31 : FloatWorkload<NormalizationQueueDescriptor>(descriptor, info)
surmeh013537c2c2018-05-18 16:31:43 +010032 , m_NormalizationLayer(memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000033{
arovir019e53a352018-08-31 15:26:35 +010034 m_Data.ValidateInputsOutputs("NeonNormalizationFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000035 std::string reasonIfUnsupported;
arovir01085f0a42018-10-08 14:48:19 +010036 if (!IsNeonNormalizationDescParamsSupported(Optional<std::string&>(reasonIfUnsupported), m_Data.m_Parameters))
telsoa014fcda012018-03-09 14:13:49 +000037 {
38 throw UnimplementedException(reasonIfUnsupported);
39 }
40
telsoa01c577f2c2018-08-31 09:22:23 +010041 // Input and output tensors have to have the same dimensionality.
telsoa014fcda012018-03-09 14:13:49 +000042 if (info.m_InputTensorInfos[0].GetShape()[1] != info.m_OutputTensorInfos[0].GetShape()[1]
43 || info.m_InputTensorInfos[0].GetShape()[0] != info.m_OutputTensorInfos[0].GetShape()[0]
44 || info.m_InputTensorInfos[0].GetShape()[3] != info.m_OutputTensorInfos[0].GetShape()[3]
45 || info.m_InputTensorInfos[0].GetShape()[2] != info.m_OutputTensorInfos[0].GetShape()[2])
46 {
47 throw InvalidArgumentException("Normalization requires input and output tensors to have equal dimensionality.");
48 }
49
50 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
51 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
narpra0155a97bc2018-10-02 14:35:53 +010052 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
53 input.info()->set_data_layout(aclDataLayout);
54 output.info()->set_data_layout(aclDataLayout);
telsoa014fcda012018-03-09 14:13:49 +000055
56 const arm_compute::NormType normType =
57 ConvertNormalizationAlgorithmChannelToAclNormType(m_Data.m_Parameters.m_NormChannelType);
58 arm_compute::NormalizationLayerInfo normalizationInfo(normType,
59 m_Data.m_Parameters.m_NormSize,
60 m_Data.m_Parameters.m_Alpha,
61 m_Data.m_Parameters.m_Beta,
62 m_Data.m_Parameters.m_K,
63 false);
64
65 m_NormalizationLayer.configure(&input, &output, normalizationInfo);
66}
67
arovir019e53a352018-08-31 15:26:35 +010068void NeonNormalizationFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000069{
arovir019e53a352018-08-31 15:26:35 +010070 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonNormalizationFloatWorkload_Execute");
telsoa014fcda012018-03-09 14:13:49 +000071 m_NormalizationLayer.run();
72}
73
74} //namespace armnn