blob: 142cbc230fbfa60e4ae1e5a08b6f6b0b10c8f535 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Matthew Benthamd8777392018-10-08 09:38:55 +01006#include "ClDepthwiseConvolutionWorkload.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01007
8#include "TypeUtils.hpp"
Matthew Benthamd8777392018-10-08 09:38:55 +01009#include "ClWorkloadUtils.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010010
David Beck711fa312018-09-24 10:46:38 +010011#include <backends/aclCommon/ArmComputeUtils.hpp>
12#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
David Beckac42efd2018-09-26 17:41:13 +010013#include <backends/cl/ClTensorHandle.hpp>
David Beck711fa312018-09-24 10:46:38 +010014#include <backends/CpuTensorHandle.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010015
Matthew Benthamd8777392018-10-08 09:38:55 +010016#include <arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h>
17
telsoa01c577f2c2018-08-31 09:22:23 +010018namespace armnn
19{
20
21using namespace armcomputetensorutils;
22
23arm_compute::Status ClDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
24 const TensorInfo& output,
25 const DepthwiseConvolution2dDescriptor& descriptor,
26 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010027 const Optional<TensorInfo>& biases)
telsoa01c577f2c2018-08-31 09:22:23 +010028{
Nikhil Raja05c2102018-09-25 16:16:13 +010029 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
30 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
31 const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010032
33 arm_compute::TensorInfo aclBiasesInfo;
34 arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
arovir01a6824102018-08-28 17:40:45 +010035
telsoa01c577f2c2018-08-31 09:22:23 +010036 if (descriptor.m_BiasEnabled)
37 {
David Beck5eec11d2018-10-04 15:43:17 +010038 BOOST_ASSERT(biases.has_value());
arovir01a6824102018-08-28 17:40:45 +010039
David Beck5eec11d2018-10-04 15:43:17 +010040 aclBiasesInfo = BuildArmComputeTensorInfo(biases.value(), descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010041 optionalAclBiasesInfo = &aclBiasesInfo;
42 }
43
44 const arm_compute::PadStrideInfo aclPadStrideInfo = BuildArmComputePadStrideInfo(descriptor);
45 const unsigned int aclDepthMultiplier = weights.GetShape()[0];
46
47 return arm_compute::CLDepthwiseConvolutionLayer::validate(&aclInputInfo,
48 &aclWeightsInfo,
49 optionalAclBiasesInfo,
50 &aclOutputInfo,
51 aclPadStrideInfo,
52 aclDepthMultiplier);
53}
54
Matthew Benthamd8777392018-10-08 09:38:55 +010055ClDepthwiseConvolutionWorkload::ClDepthwiseConvolutionWorkload(
telsoa01c577f2c2018-08-31 09:22:23 +010056 const DepthwiseConvolution2dQueueDescriptor& descriptor,
57 const WorkloadInfo& info)
Matthew Benthamd8777392018-10-08 09:38:55 +010058 : BaseWorkload<DepthwiseConvolution2dQueueDescriptor>(descriptor, info)
telsoa01c577f2c2018-08-31 09:22:23 +010059{
60 auto& weightInfo = m_Data.m_Weight->GetTensorInfo();
61
62 m_KernelTensor = std::make_unique<arm_compute::CLTensor>();
63 BuildArmComputeTensor(*m_KernelTensor, weightInfo);
64
65 if (m_Data.m_Parameters.m_BiasEnabled)
66 {
67 m_BiasTensor = std::make_unique<arm_compute::CLTensor>();
68 BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
69 }
70
71 arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
72 m_Data.m_Parameters.m_StrideY,
73 m_Data.m_Parameters.m_PadLeft,
74 m_Data.m_Parameters.m_PadRight,
75 m_Data.m_Parameters.m_PadTop,
76 m_Data.m_Parameters.m_PadBottom,
77 arm_compute::DimensionRoundingType::FLOOR);
78
Matthew Benthamd8777392018-10-08 09:38:55 +010079 std::string name = std::string("ClDepthwiseConvolutionWorkload");
telsoa01c577f2c2018-08-31 09:22:23 +010080 m_Data.ValidateInputsOutputs(name, 1, 1);
81
82 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
83 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
84
85 const unsigned int depthMultiplier = weightInfo.GetShape()[0];
86
87 //Check for optimisation opportunities.
88 bool use3x3Optimisation = (weightInfo.GetShape()[3] == 3) && (weightInfo.GetShape()[2] == 3);
89 if (use3x3Optimisation)
90 {
91 m_DepthwiseConvolutionLayer = std::make_unique<arm_compute::CLDepthwiseConvolutionLayer3x3>();
92 static_cast<arm_compute::CLDepthwiseConvolutionLayer3x3*>(m_DepthwiseConvolutionLayer.get())->configure(
93 &input,
94 m_KernelTensor.get(),
95 m_BiasTensor.get(),
96 &output,
97 padStrideInfo,
98 depthMultiplier);
99 }
100 else
101 {
102 m_DepthwiseConvolutionLayer = std::make_unique<arm_compute::CLDepthwiseConvolutionLayer>();
103 static_cast<arm_compute::CLDepthwiseConvolutionLayer*>(m_DepthwiseConvolutionLayer.get())->configure(
104 &input,
105 m_KernelTensor.get(),
106 m_BiasTensor.get(),
107 &output,
108 padStrideInfo,
109 depthMultiplier);
110 }
111
112 BOOST_ASSERT(m_DepthwiseConvolutionLayer);
Matthew Benthamd8777392018-10-08 09:38:55 +0100113
114 InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
115
116 if (m_BiasTensor)
117 {
118 InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
119 }
120
121 m_DepthwiseConvolutionLayer->prepare();
122 FreeUnusedTensors();
telsoa01c577f2c2018-08-31 09:22:23 +0100123}
124
Matthew Benthamd8777392018-10-08 09:38:55 +0100125void ClDepthwiseConvolutionWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100126{
127 FreeTensorIfUnused(m_KernelTensor);
128 FreeTensorIfUnused(m_BiasTensor);
129}
130
Matthew Benthamd8777392018-10-08 09:38:55 +0100131void ClDepthwiseConvolutionWorkload::Execute() const
132{
133 ARMNN_SCOPED_PROFILING_EVENT_CL("ClDepthwiseConvolutionWorkload_Execute");
134 BOOST_ASSERT(m_DepthwiseConvolutionLayer);
135
136 m_DepthwiseConvolutionLayer->run();
137}
telsoa01c577f2c2018-08-31 09:22:23 +0100138
139} // namespace armnn