blob: 361d6f87a5b25a1df45712128d9e06bfc47fb14c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "ClBatchNormalizationFloatWorkload.hpp"
Matthew Bentham14e46692018-09-20 15:35:30 +01007#include "ClWorkloadUtils.hpp"
8
Mike Kelly07810fc2020-11-12 10:58:48 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <aclCommon/ArmComputeUtils.hpp>
James Conroy1f58f032021-04-27 17:13:27 +010011#include <backendsCommon/TensorHandle.hpp>
Mike Kelly07810fc2020-11-12 10:58:48 +000012#include <cl/ClLayerSupport.hpp>
Mike Kelly0d4ed392020-11-13 15:26:41 +000013#include <cl/ClTensorHandle.hpp>
Mike Kelly07810fc2020-11-12 10:58:48 +000014
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17using namespace armcomputetensorutils;
18
telsoa01c577f2c2018-08-31 09:22:23 +010019arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const TensorInfo& mean,
22 const TensorInfo& var,
23 const TensorInfo& beta,
24 const TensorInfo& gamma,
Mike Kelly07810fc2020-11-12 10:58:48 +000025 const BatchNormalizationDescriptor& desc,
26 const ActivationDescriptor* activationDescriptor)
telsoa01c577f2c2018-08-31 09:22:23 +010027{
Nikhil Rajd1340932018-10-18 14:27:50 +010028 const arm_compute::TensorInfo aclInputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000029 armcomputetensorutils::BuildArmComputeTensorInfo(input, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010030 const arm_compute::TensorInfo aclOutputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000031 armcomputetensorutils::BuildArmComputeTensorInfo(output, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010032 const arm_compute::TensorInfo aclMeanInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000033 armcomputetensorutils::BuildArmComputeTensorInfo(mean, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010034 const arm_compute::TensorInfo aclVarInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000035 armcomputetensorutils::BuildArmComputeTensorInfo(var, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010036 const arm_compute::TensorInfo aclBetaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000037 armcomputetensorutils::BuildArmComputeTensorInfo(beta, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010038 const arm_compute::TensorInfo aclGammaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000039 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, desc.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010040
Mike Kelly07810fc2020-11-12 10:58:48 +000041 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
42 activationDescriptor);
43
telsoa01c577f2c2018-08-31 09:22:23 +010044 return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
45 &aclOutputInfo,
46 &aclMeanInfo,
47 &aclVarInfo,
48 &aclBetaInfo,
49 &aclGammaInfo,
Mike Kelly07810fc2020-11-12 10:58:48 +000050 desc.m_Eps,
51 activationInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010052}
53
arovir019e53a352018-08-31 15:26:35 +010054ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
Sadik Armagane9444752020-12-02 11:28:58 +000055 const BatchNormalizationQueueDescriptor& descriptor,
56 const WorkloadInfo& info,
57 const arm_compute::CLCompileContext& clCompileContext)
telsoa01c577f2c2018-08-31 09:22:23 +010058 : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000059{
telsoa01c577f2c2018-08-31 09:22:23 +010060 m_Mean = std::make_unique<arm_compute::CLTensor>();
61 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
62
63 m_Variance = std::make_unique<arm_compute::CLTensor>();
64 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
65
66 m_Gamma = std::make_unique<arm_compute::CLTensor>();
67 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
68
69 m_Beta = std::make_unique<arm_compute::CLTensor>();
70 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000071
arovir019e53a352018-08-31 15:26:35 +010072 m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000073
74 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
75 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000076
Matthew Bentham8800c002018-11-19 13:19:28 +000077 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010078 input.info()->set_data_layout(aclDataLayout);
79 output.info()->set_data_layout(aclDataLayout);
80
Mike Kelly07810fc2020-11-12 10:58:48 +000081 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
82
Sadik Armagane9444752020-12-02 11:28:58 +000083 m_Layer.configure(clCompileContext,
84 &input,
telsoa01c577f2c2018-08-31 09:22:23 +010085 &output,
86 m_Mean.get(),
87 m_Variance.get(),
88 m_Beta.get(),
89 m_Gamma.get(),
Mike Kelly07810fc2020-11-12 10:58:48 +000090 m_Data.m_Parameters.m_Eps,
91 activationInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010092
Matthew Bentham785df502018-09-21 10:29:58 +010093 InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
94 InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
95 InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
96 InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
telsoa01c577f2c2018-08-31 09:22:23 +010097
98 // Force Compute Library to perform the necessary copying and reshaping, after which
99 // delete all the input tensors that will no longer be needed
100 m_Layer.prepare();
101 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +0000102}
103
arovir019e53a352018-08-31 15:26:35 +0100104void ClBatchNormalizationFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +0000105{
arovir019e53a352018-08-31 15:26:35 +0100106 ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloatWorkload_Execute");
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +0100107 RunClFunction(m_Layer, CHECK_LOCATION());
telsoa014fcda012018-03-09 14:13:49 +0000108}
109
arovir019e53a352018-08-31 15:26:35 +0100110void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100111{
112 FreeTensorIfUnused(m_Mean);
113 FreeTensorIfUnused(m_Variance);
114 FreeTensorIfUnused(m_Gamma);
115 FreeTensorIfUnused(m_Beta);
116}
117
Matthew Bentham14e46692018-09-20 15:35:30 +0100118} //namespace armnn