blob: cd931e3797c26470a2677f9dc4b868df023232cc [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matthew Benthamc48ac8c2018-12-12 16:15:59 +00006#include "NeonBatchNormalizationWorkload.hpp"
Matthew Benthamd80a7122019-01-08 17:52:37 +00007
8#include "NeonWorkloadUtils.hpp"
9
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <backendsCommon/CpuTensorHandle.hpp>
11#include <aclCommon/ArmComputeTensorUtils.hpp>
Matthew Benthamd80a7122019-01-08 17:52:37 +000012
13#include <arm_compute/runtime/NEON/functions/NEBatchNormalizationLayer.h>
telsoa014fcda012018-03-09 14:13:49 +000014
15namespace armnn
16{
17using namespace armcomputetensorutils;
18
telsoa01c577f2c2018-08-31 09:22:23 +010019
20arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
21 const TensorInfo& output,
22 const TensorInfo& mean,
23 const TensorInfo& var,
24 const TensorInfo& beta,
25 const TensorInfo& gamma,
26 const BatchNormalizationDescriptor& descriptor)
27{
Nikhil Rajd1340932018-10-18 14:27:50 +010028 const arm_compute::TensorInfo aclInputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000029 armcomputetensorutils::BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010030 const arm_compute::TensorInfo aclOutputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000031 armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010032 const arm_compute::TensorInfo aclMeanInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000033 armcomputetensorutils::BuildArmComputeTensorInfo(mean, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010034 const arm_compute::TensorInfo aclVarInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000035 armcomputetensorutils::BuildArmComputeTensorInfo(var, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010036 const arm_compute::TensorInfo aclBetaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000037 armcomputetensorutils::BuildArmComputeTensorInfo(beta, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010038 const arm_compute::TensorInfo aclGammaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000039 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010040
41 return arm_compute::NEBatchNormalizationLayer::validate(&aclInputInfo,
42 &aclOutputInfo,
43 &aclMeanInfo,
44 &aclVarInfo,
45 &aclBetaInfo,
46 &aclGammaInfo,
47 descriptor.m_Eps);
48}
49
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000050NeonBatchNormalizationWorkload::NeonBatchNormalizationWorkload(
telsoa014fcda012018-03-09 14:13:49 +000051 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000052 : BaseWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000053{
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000054 m_Data.ValidateInputsOutputs("NeonBatchNormalizationWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000055
Derek Lambertic81855f2019-06-13 17:34:19 +010056 arm_compute::ITensor& input = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
57 arm_compute::ITensor& output = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000058
Matthew Bentham8800c002018-11-19 13:19:28 +000059 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010060 input.info()->set_data_layout(aclDataLayout);
61 output.info()->set_data_layout(aclDataLayout);
62
telsoa01c577f2c2018-08-31 09:22:23 +010063 m_Mean = std::make_unique<arm_compute::Tensor>();
64 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000065
telsoa01c577f2c2018-08-31 09:22:23 +010066 m_Variance = std::make_unique<arm_compute::Tensor>();
67 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000068
telsoa01c577f2c2018-08-31 09:22:23 +010069 m_Gamma = std::make_unique<arm_compute::Tensor>();
70 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
71
72 m_Beta = std::make_unique<arm_compute::Tensor>();
73 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
74
Matthew Benthamd80a7122019-01-08 17:52:37 +000075 auto layer = std::make_unique<arm_compute::NEBatchNormalizationLayer>();
76 layer->configure(&input,
77 &output,
78 m_Mean.get(),
79 m_Variance.get(),
80 m_Beta.get(),
81 m_Gamma.get(),
82 m_Data.m_Parameters.m_Eps);
83 m_Layer.reset(layer.release());
telsoa01c577f2c2018-08-31 09:22:23 +010084
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010085 InitializeArmComputeTensorData(*m_Mean, m_Data.m_Mean);
86 InitializeArmComputeTensorData(*m_Variance, m_Data.m_Variance);
87 InitializeArmComputeTensorData(*m_Gamma, m_Data.m_Gamma);
88 InitializeArmComputeTensorData(*m_Beta, m_Data.m_Beta);
telsoa01c577f2c2018-08-31 09:22:23 +010089
90 // Force Compute Library to perform the necessary copying and reshaping, after which
91 // delete all the input tensors that will no longer be needed
Matthew Benthamd80a7122019-01-08 17:52:37 +000092 m_Layer->prepare();
telsoa01c577f2c2018-08-31 09:22:23 +010093 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000094}
95
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000096void NeonBatchNormalizationWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000097{
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000098 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonBatchNormalizationWorkload_Execute");
Matthew Benthamd80a7122019-01-08 17:52:37 +000099 m_Layer->run();
telsoa014fcda012018-03-09 14:13:49 +0000100}
101
Matthew Benthamc48ac8c2018-12-12 16:15:59 +0000102void NeonBatchNormalizationWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100103{
104 FreeTensorIfUnused(m_Mean);
105 FreeTensorIfUnused(m_Variance);
106 FreeTensorIfUnused(m_Gamma);
107 FreeTensorIfUnused(m_Beta);
108}
109
telsoa014fcda012018-03-09 14:13:49 +0000110} //namespace armnn