blob: 72565966b85afb561f3aba9d05edec16e87e8af9 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "ClConvolution2dFloatWorkload.hpp"
David Beck711fa312018-09-24 10:46:38 +01007#include <backends/ClTensorHandle.hpp>
8#include <backends/CpuTensorHandle.hpp>
9#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
10#include <backends/ClLayerSupport.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matthew Bentham14e46692018-09-20 15:35:30 +010012#include "ClWorkloadUtils.hpp"
13
telsoa014fcda012018-03-09 14:13:49 +000014namespace armnn
15{
16using namespace armcomputetensorutils;
17
arovir019e53a352018-08-31 15:26:35 +010018ClConvolution2dFloatWorkload::ClConvolution2dFloatWorkload(const Convolution2dQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010019 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
telsoa01c577f2c2018-08-31 09:22:23 +010020 : FloatWorkload<Convolution2dQueueDescriptor>(descriptor, info)
surmeh013537c2c2018-05-18 16:31:43 +010021 , m_ConvolutionLayer(memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000022{
23
telsoa01c577f2c2018-08-31 09:22:23 +010024 // todo: check tensor shapes match.
telsoa014fcda012018-03-09 14:13:49 +000025 const TensorInfo& weightInfo = m_Data.m_Weight->GetTensorInfo();
telsoa01c577f2c2018-08-31 09:22:23 +010026
27 m_KernelTensor = std::make_unique<arm_compute::CLTensor>();
Francis Murtagh351d13d2018-09-24 15:01:18 +010028 BuildArmComputeTensor(*m_KernelTensor, weightInfo, descriptor.m_DataLayout);
telsoa014fcda012018-03-09 14:13:49 +000029
30 arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
31 m_Data.m_Parameters.m_StrideY,
32 m_Data.m_Parameters.m_PadLeft,
33 m_Data.m_Parameters.m_PadRight,
34 m_Data.m_Parameters.m_PadTop,
35 m_Data.m_Parameters.m_PadBottom,
36 arm_compute::DimensionRoundingType::FLOOR);
37
telsoa014fcda012018-03-09 14:13:49 +000038 if (m_Data.m_Parameters.m_BiasEnabled)
39 {
telsoa01c577f2c2018-08-31 09:22:23 +010040 m_BiasTensor = std::make_unique<arm_compute::CLTensor>();
Francis Murtagh351d13d2018-09-24 15:01:18 +010041 BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo(), descriptor.m_DataLayout);
telsoa014fcda012018-03-09 14:13:49 +000042 }
43
44 m_Data.ValidateInputsOutputs("ClConvolution2dFloat32Workload", 1, 1);
45
46 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
47 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
48
surmeh013537c2c2018-05-18 16:31:43 +010049 m_ConvolutionLayer.configure(&input,
telsoa01c577f2c2018-08-31 09:22:23 +010050 m_KernelTensor.get(),
51 m_BiasTensor.get(),
surmeh013537c2c2018-05-18 16:31:43 +010052 &output,
53 padStrideInfo);
telsoa014fcda012018-03-09 14:13:49 +000054
Matthew Bentham785df502018-09-21 10:29:58 +010055 InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
telsoa014fcda012018-03-09 14:13:49 +000056
telsoa01c577f2c2018-08-31 09:22:23 +010057 if (m_BiasTensor)
telsoa014fcda012018-03-09 14:13:49 +000058 {
Matthew Bentham785df502018-09-21 10:29:58 +010059 InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
telsoa014fcda012018-03-09 14:13:49 +000060 }
telsoa01c577f2c2018-08-31 09:22:23 +010061
62 // Force Compute Library to perform the necessary copying and reshaping, after which
63 // delete all the input tensors that will no longer be needed
64 m_ConvolutionLayer.prepare();
65 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +000066}
67
arovir019e53a352018-08-31 15:26:35 +010068void ClConvolution2dFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000069{
telsoa01c577f2c2018-08-31 09:22:23 +010070 ARMNN_SCOPED_PROFILING_EVENT_CL("ClConvolution2dFloat32Workload_Execute");
telsoa014fcda012018-03-09 14:13:49 +000071
surmeh013537c2c2018-05-18 16:31:43 +010072 m_ConvolutionLayer.run();
telsoa014fcda012018-03-09 14:13:49 +000073}
74
arovir019e53a352018-08-31 15:26:35 +010075void ClConvolution2dFloatWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +010076{
77 FreeTensorIfUnused(m_KernelTensor);
78 FreeTensorIfUnused(m_BiasTensor);
79}
80
surmeh013537c2c2018-05-18 16:31:43 +010081} //namespace armnn