blob: ff777dbf9b6d24cdeaf11ac596d8f99d9ffae0d5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matthew Benthamc48ac8c2018-12-12 16:15:59 +00006#include "NeonBatchNormalizationWorkload.hpp"
Matthew Benthamd80a7122019-01-08 17:52:37 +00007
8#include "NeonWorkloadUtils.hpp"
9
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010011#include <armnn/utility/PolymorphicDowncast.hpp>
12#include <backendsCommon/CpuTensorHandle.hpp>
Matthew Benthamd80a7122019-01-08 17:52:37 +000013
14#include <arm_compute/runtime/NEON/functions/NEBatchNormalizationLayer.h>
telsoa014fcda012018-03-09 14:13:49 +000015
16namespace armnn
17{
18using namespace armcomputetensorutils;
19
telsoa01c577f2c2018-08-31 09:22:23 +010020
21arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
22 const TensorInfo& output,
23 const TensorInfo& mean,
24 const TensorInfo& var,
25 const TensorInfo& beta,
26 const TensorInfo& gamma,
27 const BatchNormalizationDescriptor& descriptor)
28{
Nikhil Rajd1340932018-10-18 14:27:50 +010029 const arm_compute::TensorInfo aclInputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000030 armcomputetensorutils::BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010031 const arm_compute::TensorInfo aclOutputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000032 armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010033 const arm_compute::TensorInfo aclMeanInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000034 armcomputetensorutils::BuildArmComputeTensorInfo(mean, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010035 const arm_compute::TensorInfo aclVarInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000036 armcomputetensorutils::BuildArmComputeTensorInfo(var, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010037 const arm_compute::TensorInfo aclBetaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000038 armcomputetensorutils::BuildArmComputeTensorInfo(beta, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010039 const arm_compute::TensorInfo aclGammaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000040 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010041
42 return arm_compute::NEBatchNormalizationLayer::validate(&aclInputInfo,
43 &aclOutputInfo,
44 &aclMeanInfo,
45 &aclVarInfo,
46 &aclBetaInfo,
47 &aclGammaInfo,
48 descriptor.m_Eps);
49}
50
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000051NeonBatchNormalizationWorkload::NeonBatchNormalizationWorkload(
telsoa014fcda012018-03-09 14:13:49 +000052 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000053 : BaseWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000054{
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000055 m_Data.ValidateInputsOutputs("NeonBatchNormalizationWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000056
Jan Eilersbb446e52020-04-02 13:56:54 +010057 arm_compute::ITensor& input = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
58 arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000059
Matthew Bentham8800c002018-11-19 13:19:28 +000060 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010061 input.info()->set_data_layout(aclDataLayout);
62 output.info()->set_data_layout(aclDataLayout);
63
telsoa01c577f2c2018-08-31 09:22:23 +010064 m_Mean = std::make_unique<arm_compute::Tensor>();
65 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000066
telsoa01c577f2c2018-08-31 09:22:23 +010067 m_Variance = std::make_unique<arm_compute::Tensor>();
68 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000069
telsoa01c577f2c2018-08-31 09:22:23 +010070 m_Gamma = std::make_unique<arm_compute::Tensor>();
71 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
72
73 m_Beta = std::make_unique<arm_compute::Tensor>();
74 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
75
Matthew Benthamd80a7122019-01-08 17:52:37 +000076 auto layer = std::make_unique<arm_compute::NEBatchNormalizationLayer>();
77 layer->configure(&input,
78 &output,
79 m_Mean.get(),
80 m_Variance.get(),
81 m_Beta.get(),
82 m_Gamma.get(),
83 m_Data.m_Parameters.m_Eps);
84 m_Layer.reset(layer.release());
telsoa01c577f2c2018-08-31 09:22:23 +010085
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010086 InitializeArmComputeTensorData(*m_Mean, m_Data.m_Mean);
87 InitializeArmComputeTensorData(*m_Variance, m_Data.m_Variance);
88 InitializeArmComputeTensorData(*m_Gamma, m_Data.m_Gamma);
89 InitializeArmComputeTensorData(*m_Beta, m_Data.m_Beta);
telsoa01c577f2c2018-08-31 09:22:23 +010090
91 // Force Compute Library to perform the necessary copying and reshaping, after which
92 // delete all the input tensors that will no longer be needed
Matthew Benthamd80a7122019-01-08 17:52:37 +000093 m_Layer->prepare();
telsoa01c577f2c2018-08-31 09:22:23 +010094 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000095}
96
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000097void NeonBatchNormalizationWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000098{
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000099 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonBatchNormalizationWorkload_Execute");
Matthew Benthamd80a7122019-01-08 17:52:37 +0000100 m_Layer->run();
telsoa014fcda012018-03-09 14:13:49 +0000101}
102
Matthew Benthamc48ac8c2018-12-12 16:15:59 +0000103void NeonBatchNormalizationWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100104{
105 FreeTensorIfUnused(m_Mean);
106 FreeTensorIfUnused(m_Variance);
107 FreeTensorIfUnused(m_Gamma);
108 FreeTensorIfUnused(m_Beta);
109}
110
telsoa014fcda012018-03-09 14:13:49 +0000111} //namespace armnn