blob: fa0be85100222965993c818a86c0fcada66f0331 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "ClBatchNormalizationFloatWorkload.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <cl/ClTensorHandle.hpp>
8#include <backendsCommon/CpuTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <cl/ClLayerSupport.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matthew Bentham14e46692018-09-20 15:35:30 +010012#include "ClWorkloadUtils.hpp"
13
telsoa014fcda012018-03-09 14:13:49 +000014namespace armnn
15{
16using namespace armcomputetensorutils;
17
telsoa01c577f2c2018-08-31 09:22:23 +010018arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
19 const TensorInfo& output,
20 const TensorInfo& mean,
21 const TensorInfo& var,
22 const TensorInfo& beta,
23 const TensorInfo& gamma,
24 const BatchNormalizationDescriptor &desc)
25{
Nikhil Rajd1340932018-10-18 14:27:50 +010026 const arm_compute::TensorInfo aclInputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000027 armcomputetensorutils::BuildArmComputeTensorInfo(input, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010028 const arm_compute::TensorInfo aclOutputInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000029 armcomputetensorutils::BuildArmComputeTensorInfo(output, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010030 const arm_compute::TensorInfo aclMeanInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000031 armcomputetensorutils::BuildArmComputeTensorInfo(mean, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010032 const arm_compute::TensorInfo aclVarInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000033 armcomputetensorutils::BuildArmComputeTensorInfo(var, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010034 const arm_compute::TensorInfo aclBetaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000035 armcomputetensorutils::BuildArmComputeTensorInfo(beta, desc.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010036 const arm_compute::TensorInfo aclGammaInfo =
Matthew Bentham8800c002018-11-19 13:19:28 +000037 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, desc.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010038
39 return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
40 &aclOutputInfo,
41 &aclMeanInfo,
42 &aclVarInfo,
43 &aclBetaInfo,
44 &aclGammaInfo,
45 desc.m_Eps);
46}
47
arovir019e53a352018-08-31 15:26:35 +010048ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
telsoa014fcda012018-03-09 14:13:49 +000049 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
telsoa01c577f2c2018-08-31 09:22:23 +010050 : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000051{
telsoa01c577f2c2018-08-31 09:22:23 +010052 m_Mean = std::make_unique<arm_compute::CLTensor>();
53 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
54
55 m_Variance = std::make_unique<arm_compute::CLTensor>();
56 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
57
58 m_Gamma = std::make_unique<arm_compute::CLTensor>();
59 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
60
61 m_Beta = std::make_unique<arm_compute::CLTensor>();
62 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000063
arovir019e53a352018-08-31 15:26:35 +010064 m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000065
66 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
67 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000068
Matthew Bentham8800c002018-11-19 13:19:28 +000069 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
Nikhil Rajd1340932018-10-18 14:27:50 +010070 input.info()->set_data_layout(aclDataLayout);
71 output.info()->set_data_layout(aclDataLayout);
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 m_Layer.configure(&input,
74 &output,
75 m_Mean.get(),
76 m_Variance.get(),
77 m_Beta.get(),
78 m_Gamma.get(),
79 m_Data.m_Parameters.m_Eps);
80
Matthew Bentham785df502018-09-21 10:29:58 +010081 InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
82 InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
83 InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
84 InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
telsoa01c577f2c2018-08-31 09:22:23 +010085
86 // Force Compute Library to perform the necessary copying and reshaping, after which
87 // delete all the input tensors that will no longer be needed
88 m_Layer.prepare();
89 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000090}
91
arovir019e53a352018-08-31 15:26:35 +010092void ClBatchNormalizationFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000093{
arovir019e53a352018-08-31 15:26:35 +010094 ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloatWorkload_Execute");
Aron Virginas-Tara8e06ed2018-10-19 16:46:15 +010095 RunClFunction(m_Layer, CHECK_LOCATION());
telsoa014fcda012018-03-09 14:13:49 +000096}
97
arovir019e53a352018-08-31 15:26:35 +010098void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +010099{
100 FreeTensorIfUnused(m_Mean);
101 FreeTensorIfUnused(m_Variance);
102 FreeTensorIfUnused(m_Gamma);
103 FreeTensorIfUnused(m_Beta);
104}
105
Matthew Bentham14e46692018-09-20 15:35:30 +0100106} //namespace armnn