blob: 5bff7a63c9d9c5e89f4bf1e649c3b21df091d22e [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "ClBatchNormalizationFloatWorkload.hpp"
David Beckac42efd2018-09-26 17:41:13 +01007#include <backends/cl/ClTensorHandle.hpp>
David Beck711fa312018-09-24 10:46:38 +01008#include <backends/CpuTensorHandle.hpp>
9#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
David Beckac42efd2018-09-26 17:41:13 +010010#include <backends/cl/ClLayerSupport.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matthew Bentham14e46692018-09-20 15:35:30 +010012#include "ClWorkloadUtils.hpp"
13
telsoa014fcda012018-03-09 14:13:49 +000014namespace armnn
15{
16using namespace armcomputetensorutils;
17
telsoa01c577f2c2018-08-31 09:22:23 +010018arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
19 const TensorInfo& output,
20 const TensorInfo& mean,
21 const TensorInfo& var,
22 const TensorInfo& beta,
23 const TensorInfo& gamma,
24 const BatchNormalizationDescriptor &desc)
25{
26 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
27 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
28 const arm_compute::TensorInfo aclMeanInfo = BuildArmComputeTensorInfo(mean);
29 const arm_compute::TensorInfo aclVarInfo = BuildArmComputeTensorInfo(var);
30 const arm_compute::TensorInfo aclBetaInfo = BuildArmComputeTensorInfo(beta);
31 const arm_compute::TensorInfo aclGammaInfo = BuildArmComputeTensorInfo(gamma);
32
33 return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
34 &aclOutputInfo,
35 &aclMeanInfo,
36 &aclVarInfo,
37 &aclBetaInfo,
38 &aclGammaInfo,
39 desc.m_Eps);
40}
41
arovir019e53a352018-08-31 15:26:35 +010042ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
telsoa014fcda012018-03-09 14:13:49 +000043 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
telsoa01c577f2c2018-08-31 09:22:23 +010044 : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000045{
telsoa01c577f2c2018-08-31 09:22:23 +010046 m_Mean = std::make_unique<arm_compute::CLTensor>();
47 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
48
49 m_Variance = std::make_unique<arm_compute::CLTensor>();
50 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
51
52 m_Gamma = std::make_unique<arm_compute::CLTensor>();
53 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
54
55 m_Beta = std::make_unique<arm_compute::CLTensor>();
56 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000057
arovir019e53a352018-08-31 15:26:35 +010058 m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000059
60 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
61 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000062
telsoa01c577f2c2018-08-31 09:22:23 +010063 m_Layer.configure(&input,
64 &output,
65 m_Mean.get(),
66 m_Variance.get(),
67 m_Beta.get(),
68 m_Gamma.get(),
69 m_Data.m_Parameters.m_Eps);
70
Matthew Bentham785df502018-09-21 10:29:58 +010071 InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
72 InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
73 InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
74 InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
telsoa01c577f2c2018-08-31 09:22:23 +010075
76 // Force Compute Library to perform the necessary copying and reshaping, after which
77 // delete all the input tensors that will no longer be needed
78 m_Layer.prepare();
79 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000080}
81
arovir019e53a352018-08-31 15:26:35 +010082void ClBatchNormalizationFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000083{
arovir019e53a352018-08-31 15:26:35 +010084 ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloatWorkload_Execute");
telsoa014fcda012018-03-09 14:13:49 +000085 m_Layer.run();
86}
87
arovir019e53a352018-08-31 15:26:35 +010088void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +010089{
90 FreeTensorIfUnused(m_Mean);
91 FreeTensorIfUnused(m_Variance);
92 FreeTensorIfUnused(m_Gamma);
93 FreeTensorIfUnused(m_Beta);
94}
95
Matthew Bentham14e46692018-09-20 15:35:30 +010096} //namespace armnn