blob: 389605f17d5c56958cd08c599f747477d7ec4c5f [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin588cbdf2022-01-19 15:55:37 +00002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "ClBatchNormalizationFloatWorkload.hpp"
Matthew Bentham14e46692018-09-20 15:35:30 +01007#include "ClWorkloadUtils.hpp"
8
Mike Kelly07810fc2020-11-12 10:58:48 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <aclCommon/ArmComputeUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000011#include <armnn/backends/TensorHandle.hpp>
Mike Kelly07810fc2020-11-12 10:58:48 +000012#include <cl/ClLayerSupport.hpp>
Mike Kelly0d4ed392020-11-13 15:26:41 +000013#include <cl/ClTensorHandle.hpp>
Mike Kelly07810fc2020-11-12 10:58:48 +000014
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17using namespace armcomputetensorutils;
18
telsoa01c577f2c2018-08-31 09:22:23 +010019arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const TensorInfo& mean,
22 const TensorInfo& var,
23 const TensorInfo& beta,
24 const TensorInfo& gamma,
Keith Davisbcd860a2021-08-05 14:20:33 +010025 const BatchNormalizationDescriptor& descriptor,
Mike Kelly07810fc2020-11-12 10:58:48 +000026 const ActivationDescriptor* activationDescriptor)
telsoa01c577f2c2018-08-31 09:22:23 +010027{
Nikhil Rajd1340932018-10-18 14:27:50 +010028 const arm_compute::TensorInfo aclInputInfo =
Keith Davisbcd860a2021-08-05 14:20:33 +010029 armcomputetensorutils::BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010030 const arm_compute::TensorInfo aclOutputInfo =
Keith Davisbcd860a2021-08-05 14:20:33 +010031 armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010032 const arm_compute::TensorInfo aclMeanInfo =
Keith Davisbcd860a2021-08-05 14:20:33 +010033 armcomputetensorutils::BuildArmComputeTensorInfo(mean, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010034 const arm_compute::TensorInfo aclVarInfo =
Keith Davisbcd860a2021-08-05 14:20:33 +010035 armcomputetensorutils::BuildArmComputeTensorInfo(var, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010036 const arm_compute::TensorInfo aclBetaInfo =
Keith Davisbcd860a2021-08-05 14:20:33 +010037 armcomputetensorutils::BuildArmComputeTensorInfo(beta, descriptor.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010038 const arm_compute::TensorInfo aclGammaInfo =
Keith Davisbcd860a2021-08-05 14:20:33 +010039 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010040
Mike Kelly07810fc2020-11-12 10:58:48 +000041 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
42 activationDescriptor);
43
telsoa01c577f2c2018-08-31 09:22:23 +010044 return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
45 &aclOutputInfo,
46 &aclMeanInfo,
47 &aclVarInfo,
48 &aclBetaInfo,
49 &aclGammaInfo,
Keith Davisbcd860a2021-08-05 14:20:33 +010050 descriptor.m_Eps,
Mike Kelly07810fc2020-11-12 10:58:48 +000051 activationInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010052}
53
arovir019e53a352018-08-31 15:26:35 +010054ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
Sadik Armagane9444752020-12-02 11:28:58 +000055 const BatchNormalizationQueueDescriptor& descriptor,
56 const WorkloadInfo& info,
57 const arm_compute::CLCompileContext& clCompileContext)
telsoa01c577f2c2018-08-31 09:22:23 +010058 : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000059{
Keith Davisbcd860a2021-08-05 14:20:33 +010060 // Report Profiling Details
61 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchNormalizationWorkload_Construct",
62 descriptor.m_Parameters,
63 info,
64 this->GetGuid());
65
telsoa01c577f2c2018-08-31 09:22:23 +010066 m_Mean = std::make_unique<arm_compute::CLTensor>();
67 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
68
69 m_Variance = std::make_unique<arm_compute::CLTensor>();
70 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
71
72 m_Gamma = std::make_unique<arm_compute::CLTensor>();
73 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
74
75 m_Beta = std::make_unique<arm_compute::CLTensor>();
76 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000077
arovir019e53a352018-08-31 15:26:35 +010078 m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000079
80 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
81 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000082
Matthew Bentham8800c002018-11-19 13:19:28 +000083 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010084 input.info()->set_data_layout(aclDataLayout);
85 output.info()->set_data_layout(aclDataLayout);
86
Mike Kelly07810fc2020-11-12 10:58:48 +000087 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
88
Kevin May9f6862d2021-10-22 15:42:28 +010089 {
90 ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClBatchNormalizationFloatWorkload_configure");
91 m_Layer.configure(clCompileContext,
92 &input,
93 &output,
94 m_Mean.get(),
95 m_Variance.get(),
96 m_Beta.get(),
97 m_Gamma.get(),
98 m_Data.m_Parameters.m_Eps,
99 activationInfo);
100 }
telsoa01c577f2c2018-08-31 09:22:23 +0100101
Matthew Bentham785df502018-09-21 10:29:58 +0100102 InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
103 InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
104 InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
105 InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
telsoa01c577f2c2018-08-31 09:22:23 +0100106
107 // Force Compute Library to perform the necessary copying and reshaping, after which
108 // delete all the input tensors that will no longer be needed
109 m_Layer.prepare();
110 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +0000111}
112
arovir019e53a352018-08-31 15:26:35 +0100113void ClBatchNormalizationFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +0000114{
Keith Davisbcd860a2021-08-05 14:20:33 +0100115 ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchNormalizationFloatWorkload_Execute", this->GetGuid());
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +0100116 RunClFunction(m_Layer, CHECK_LOCATION());
telsoa014fcda012018-03-09 14:13:49 +0000117}
118
arovir019e53a352018-08-31 15:26:35 +0100119void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100120{
121 FreeTensorIfUnused(m_Mean);
122 FreeTensorIfUnused(m_Variance);
123 FreeTensorIfUnused(m_Gamma);
124 FreeTensorIfUnused(m_Beta);
125}
126
David Monahanec819992022-02-10 14:47:13 +0000127void ClBatchNormalizationFloatWorkload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
128{
129 ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
130 this->m_Data.m_Inputs[slot] = tensorHandle;
131 try
132 {
133 Reconfigure();
134 }
135 catch(armnn::UnimplementedException& e)
136 {
137 // Cannot reconfigure, revert the slot back and throw the exception.
138 this->m_Data.m_Inputs[slot] = backupHandle;
139 throw e;
140 }
141}
142
143// Replace output tensor handle with the given TensorHandle
144void ClBatchNormalizationFloatWorkload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
145{
146 ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
147 this->m_Data.m_Inputs[slot] = tensorHandle;
148 try
149 {
150 Reconfigure();
151 }
152 catch(armnn::UnimplementedException& e)
153 {
154 // Cannot reconfigure, revert the slot back and throw the exception.
155 this->m_Data.m_Inputs[slot] = backupHandle;
156 throw e;
157 }
158}
159
160void ClBatchNormalizationFloatWorkload::Reconfigure()
161{
162 throw armnn::UnimplementedException("Reconfigure not implemented for this workload");
163}
164
Matthew Bentham14e46692018-09-20 15:35:30 +0100165} //namespace armnn