blob: 1080f320e7cd1dafdcc2def0bc5ebe078600382e [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Nattapat Chaimanowong974b65f2018-10-15 15:07:34 +01006#include "NeonConvolution2dWorkload.hpp"
7
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <backendsCommon/CpuTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
Matthew Benthamd80a7122019-01-08 17:52:37 +000010#include <neon/workloads/NeonWorkloadUtils.hpp>
11
12#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck711fa312018-09-24 10:46:38 +010014#include <armnn/Types.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <Half.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010016
telsoa014fcda012018-03-09 14:13:49 +000017namespace armnn
18{
19
surmeh013537c2c2018-05-18 16:31:43 +010020using namespace armcomputetensorutils;
21
22arm_compute::Status NeonConvolution2dWorkloadValidate(const TensorInfo& input,
23 const TensorInfo& output,
24 const Convolution2dDescriptor& descriptor,
25 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010026 const Optional<TensorInfo>& biases)
surmeh013537c2c2018-05-18 16:31:43 +010027{
Francis Murtagh351d13d2018-09-24 15:01:18 +010028 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
29 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
30 const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
arovir01a6824102018-08-28 17:40:45 +010031
surmeh013537c2c2018-05-18 16:31:43 +010032 arm_compute::TensorInfo aclBiasesInfo;
33 arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
34
35 if (descriptor.m_BiasEnabled)
36 {
David Beck5eec11d2018-10-04 15:43:17 +010037 BOOST_ASSERT(biases.has_value());
arovir01a6824102018-08-28 17:40:45 +010038
David Beck5eec11d2018-10-04 15:43:17 +010039 aclBiasesInfo = BuildArmComputeTensorInfo(biases.value(), descriptor.m_DataLayout);
surmeh013537c2c2018-05-18 16:31:43 +010040 optionalAclBiasesInfo = &aclBiasesInfo;
41 }
42
43 arm_compute::PadStrideInfo layerInfo = BuildArmComputePadStrideInfo(descriptor);
44
45 return arm_compute::NEConvolutionLayer::validate(&aclInputInfo,
46 &aclWeightsInfo,
47 optionalAclBiasesInfo,
48 &aclOutputInfo,
49 layerInfo);
50}
51
Nattapat Chaimanowong974b65f2018-10-15 15:07:34 +010052NeonConvolution2dWorkload::NeonConvolution2dWorkload(
telsoa01c577f2c2018-08-31 09:22:23 +010053 const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info,
54 std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
Nattapat Chaimanowong974b65f2018-10-15 15:07:34 +010055 : BaseWorkload<Convolution2dQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000056{
57 using arm_compute::NEDirectConvolutionLayer;
telsoa014fcda012018-03-09 14:13:49 +000058
Nattapat Chaimanowong974b65f2018-10-15 15:07:34 +010059 m_Data.ValidateInputsOutputs("NeonConvolution2dWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000060
telsoa01c577f2c2018-08-31 09:22:23 +010061 // todo: check tensor shapes match.
telsoa014fcda012018-03-09 14:13:49 +000062
63 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
64 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
65
Francis Murtaghd59116e2018-10-04 16:03:07 +010066 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
67 input.info()->set_data_layout(aclDataLayout);
68 output.info()->set_data_layout(aclDataLayout);
69
telsoa01c577f2c2018-08-31 09:22:23 +010070 m_KernelTensor = std::make_unique<arm_compute::Tensor>();
Francis Murtaghd59116e2018-10-04 16:03:07 +010071 BuildArmComputeTensor(*m_KernelTensor, m_Data.m_Weight->GetTensorInfo(), m_Data.m_Parameters.m_DataLayout);
telsoa014fcda012018-03-09 14:13:49 +000072
telsoa014fcda012018-03-09 14:13:49 +000073 if (m_Data.m_Parameters.m_BiasEnabled)
74 {
telsoa01c577f2c2018-08-31 09:22:23 +010075 m_BiasTensor = std::make_unique<arm_compute::Tensor>();
Francis Murtaghd59116e2018-10-04 16:03:07 +010076 BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo(), m_Data.m_Parameters.m_DataLayout);
telsoa014fcda012018-03-09 14:13:49 +000077 }
78
79 arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
80 m_Data.m_Parameters.m_StrideY,
81 m_Data.m_Parameters.m_PadLeft,
82 m_Data.m_Parameters.m_PadRight,
83 m_Data.m_Parameters.m_PadTop,
84 m_Data.m_Parameters.m_PadBottom,
85 arm_compute::DimensionRoundingType::FLOOR);
86
narpra01fca75c32018-11-16 12:38:41 +000087 auto convolutionLayer = std::make_unique<arm_compute::NEConvolutionLayer>(memoryManager);
88 convolutionLayer->configure(&input,
89 m_KernelTensor.get(),
90 m_BiasTensor.get(),
91 &output,
92 padStrideInfo);
93 m_ConvolutionLayer.reset(convolutionLayer.release());
telsoa014fcda012018-03-09 14:13:49 +000094
telsoa014fcda012018-03-09 14:13:49 +000095 BOOST_ASSERT(m_ConvolutionLayer);
96
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010097 InitializeArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight);
telsoa014fcda012018-03-09 14:13:49 +000098
Nattapat Chaimanowong974b65f2018-10-15 15:07:34 +010099 if (m_Data.m_Parameters.m_BiasEnabled)
100 {
101 InitializeArmComputeTensorData(*m_BiasTensor, m_Data.m_Bias);
102 }
103
104 m_ConvolutionLayer->prepare();
105 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +0000106}
107
Nattapat Chaimanowong974b65f2018-10-15 15:07:34 +0100108void NeonConvolution2dWorkload::Execute() const
109{
110 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConvolution2dWorkload_Execute");
111 m_ConvolutionLayer->run();
112}
113
114void NeonConvolution2dWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100115{
116 FreeTensorIfUnused(m_KernelTensor);
117 FreeTensorIfUnused(m_BiasTensor);
118}
119
telsoa014fcda012018-03-09 14:13:49 +0000120} //namespace armnn