blob: f0b9a46d600955baa41d65adb2e58f1cfe26a143 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "ClConvolution2dFloatWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007#include "backends/ClTensorHandle.hpp"
8#include "backends/CpuTensorHandle.hpp"
9#include "backends/ArmComputeTensorUtils.hpp"
10#include "backends/ClLayerSupport.hpp"
11
Matthew Bentham14e46692018-09-20 15:35:30 +010012#include "ClWorkloadUtils.hpp"
13
telsoa014fcda012018-03-09 14:13:49 +000014namespace armnn
15{
16using namespace armcomputetensorutils;
17
arovir019e53a352018-08-31 15:26:35 +010018ClConvolution2dFloatWorkload::ClConvolution2dFloatWorkload(const Convolution2dQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010019 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
telsoa01c577f2c2018-08-31 09:22:23 +010020 : FloatWorkload<Convolution2dQueueDescriptor>(descriptor, info)
surmeh013537c2c2018-05-18 16:31:43 +010021 , m_ConvolutionLayer(memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000022{
23
telsoa01c577f2c2018-08-31 09:22:23 +010024 // todo: check tensor shapes match.
telsoa014fcda012018-03-09 14:13:49 +000025 const TensorInfo& weightInfo = m_Data.m_Weight->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010026
27 m_KernelTensor = std::make_unique<arm_compute::CLTensor>();
28 BuildArmComputeTensor(*m_KernelTensor, weightInfo);
telsoa014fcda012018-03-09 14:13:49 +000029
30 arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
31 m_Data.m_Parameters.m_StrideY,
32 m_Data.m_Parameters.m_PadLeft,
33 m_Data.m_Parameters.m_PadRight,
34 m_Data.m_Parameters.m_PadTop,
35 m_Data.m_Parameters.m_PadBottom,
36 arm_compute::DimensionRoundingType::FLOOR);
37
telsoa014fcda012018-03-09 14:13:49 +000038 if (m_Data.m_Parameters.m_BiasEnabled)
39 {
telsoa01c577f2c2018-08-31 09:22:23 +010040 m_BiasTensor = std::make_unique<arm_compute::CLTensor>();
41 BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000042 }
43
44 m_Data.ValidateInputsOutputs("ClConvolution2dFloat32Workload", 1, 1);
45
46 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
47 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
48
surmeh013537c2c2018-05-18 16:31:43 +010049 m_ConvolutionLayer.configure(&input,
telsoa01c577f2c2018-08-31 09:22:23 +010050 m_KernelTensor.get(),
51 m_BiasTensor.get(),
surmeh013537c2c2018-05-18 16:31:43 +010052 &output,
53 padStrideInfo);
telsoa014fcda012018-03-09 14:13:49 +000054
Matthew Bentham785df502018-09-21 10:29:58 +010055 InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
telsoa014fcda012018-03-09 14:13:49 +000056
telsoa01c577f2c2018-08-31 09:22:23 +010057 if (m_BiasTensor)
telsoa014fcda012018-03-09 14:13:49 +000058 {
Matthew Bentham785df502018-09-21 10:29:58 +010059 InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
telsoa014fcda012018-03-09 14:13:49 +000060 }
telsoa01c577f2c2018-08-31 09:22:23 +010061
62 // Force Compute Library to perform the necessary copying and reshaping, after which
63 // delete all the input tensors that will no longer be needed
64 m_ConvolutionLayer.prepare();
65 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000066}
67
arovir019e53a352018-08-31 15:26:35 +010068void ClConvolution2dFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000069{
telsoa01c577f2c2018-08-31 09:22:23 +010070 ARMNN_SCOPED_PROFILING_EVENT_CL("ClConvolution2dFloat32Workload_Execute");
telsoa014fcda012018-03-09 14:13:49 +000071
surmeh013537c2c2018-05-18 16:31:43 +010072 m_ConvolutionLayer.run();
telsoa014fcda012018-03-09 14:13:49 +000073}
74
arovir019e53a352018-08-31 15:26:35 +010075void ClConvolution2dFloatWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +010076{
77 FreeTensorIfUnused(m_KernelTensor);
78 FreeTensorIfUnused(m_BiasTensor);
79}
80
surmeh013537c2c2018-05-18 16:31:43 +010081} //namespace armnn