blob: 44d50354314297ce1f9763339ef0e3b4e79fc7a4 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matthew Benthamc48ac8c2018-12-12 16:15:59 +00006#include "NeonBatchNormalizationWorkload.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/CpuTensorHandle.hpp>
8#include <aclCommon/ArmComputeTensorUtils.hpp>
David Beck711fa312018-09-24 10:46:38 +01009#include <armnn/ArmNN.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
11namespace armnn
12{
13using namespace armcomputetensorutils;
14
telsoa01c577f2c2018-08-31 09:22:23 +010015
16arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
17 const TensorInfo& output,
18 const TensorInfo& mean,
19 const TensorInfo& var,
20 const TensorInfo& beta,
21 const TensorInfo& gamma,
22 const BatchNormalizationDescriptor& descriptor)
23{
Nikhil Rajd1340932018-10-18 14:27:50 +010024 const arm_compute::TensorInfo aclInputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000025 armcomputetensorutils::BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010026 const arm_compute::TensorInfo aclOutputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000027 armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010028 const arm_compute::TensorInfo aclMeanInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000029 armcomputetensorutils::BuildArmComputeTensorInfo(mean, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010030 const arm_compute::TensorInfo aclVarInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000031 armcomputetensorutils::BuildArmComputeTensorInfo(var, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010032 const arm_compute::TensorInfo aclBetaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000033 armcomputetensorutils::BuildArmComputeTensorInfo(beta, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010034 const arm_compute::TensorInfo aclGammaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000035 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010036
37 return arm_compute::NEBatchNormalizationLayer::validate(&aclInputInfo,
38 &aclOutputInfo,
39 &aclMeanInfo,
40 &aclVarInfo,
41 &aclBetaInfo,
42 &aclGammaInfo,
43 descriptor.m_Eps);
44}
45
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000046NeonBatchNormalizationWorkload::NeonBatchNormalizationWorkload(
telsoa014fcda012018-03-09 14:13:49 +000047 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000048 : BaseWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000049{
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000050 m_Data.ValidateInputsOutputs("NeonBatchNormalizationWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000051
52 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
53 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
54
Matthew Bentham8800c002018-11-19 13:19:28 +000055 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010056 input.info()->set_data_layout(aclDataLayout);
57 output.info()->set_data_layout(aclDataLayout);
58
telsoa01c577f2c2018-08-31 09:22:23 +010059 m_Mean = std::make_unique<arm_compute::Tensor>();
60 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000061
telsoa01c577f2c2018-08-31 09:22:23 +010062 m_Variance = std::make_unique<arm_compute::Tensor>();
63 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000064
telsoa01c577f2c2018-08-31 09:22:23 +010065 m_Gamma = std::make_unique<arm_compute::Tensor>();
66 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
67
68 m_Beta = std::make_unique<arm_compute::Tensor>();
69 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
70
71 m_Layer.configure(&input,
72 &output,
73 m_Mean.get(),
74 m_Variance.get(),
75 m_Beta.get(),
76 m_Gamma.get(),
77 m_Data.m_Parameters.m_Eps);
78
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010079 InitializeArmComputeTensorData(*m_Mean, m_Data.m_Mean);
80 InitializeArmComputeTensorData(*m_Variance, m_Data.m_Variance);
81 InitializeArmComputeTensorData(*m_Gamma, m_Data.m_Gamma);
82 InitializeArmComputeTensorData(*m_Beta, m_Data.m_Beta);
telsoa01c577f2c2018-08-31 09:22:23 +010083
84 // Force Compute Library to perform the necessary copying and reshaping, after which
85 // delete all the input tensors that will no longer be needed
86 m_Layer.prepare();
87 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000088}
89
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000090void NeonBatchNormalizationWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000091{
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000092 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonBatchNormalizationWorkload_Execute");
telsoa014fcda012018-03-09 14:13:49 +000093 m_Layer.run();
94}
95
Matthew Benthamc48ac8c2018-12-12 16:15:59 +000096void NeonBatchNormalizationWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +010097{
98 FreeTensorIfUnused(m_Mean);
99 FreeTensorIfUnused(m_Variance);
100 FreeTensorIfUnused(m_Gamma);
101 FreeTensorIfUnused(m_Beta);
102}
103
telsoa014fcda012018-03-09 14:13:49 +0000104} //namespace armnn