blob: ca7a0c575af58ae2285177ab825e41a36c935f92 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "NeonConvolution2dFloatWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007#include "backends/CpuTensorHandle.hpp"
8#include "backends/ArmComputeTensorUtils.hpp"
9#include "backends/NeonLayerSupport.hpp"
10
11namespace armnn
12{
13using namespace armcomputetensorutils;
14
arovir019e53a352018-08-31 15:26:35 +010015NeonConvolution2dFloatWorkload::NeonConvolution2dFloatWorkload(const Convolution2dQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010016 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
17 : NeonConvolution2dBaseWorkload(descriptor, info, memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +010018{
19 if (m_Data.m_Parameters.m_BiasEnabled)
20 {
telsoa01c577f2c2018-08-31 09:22:23 +010021 InitializeArmComputeTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
surmeh01bceff2f2018-03-29 16:29:27 +010022 }
telsoa01c577f2c2018-08-31 09:22:23 +010023
24 m_ConvolutionLayer->prepare();
25 FreeUnusedTensors();
surmeh01bceff2f2018-03-29 16:29:27 +010026}
telsoa014fcda012018-03-09 14:13:49 +000027
arovir019e53a352018-08-31 15:26:35 +010028void NeonConvolution2dFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000029{
arovir019e53a352018-08-31 15:26:35 +010030 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConvolution2dFloatWorkload_Execute");
telsoa014fcda012018-03-09 14:13:49 +000031 m_ConvolutionLayer->run();
32}
33
arovir019e53a352018-08-31 15:26:35 +010034void NeonConvolution2dFloatWorkload::ValidateData() const
telsoa014fcda012018-03-09 14:13:49 +000035{
arovir019e53a352018-08-31 15:26:35 +010036 m_Data.ValidateInputsOutputs("NeonConvolution2dFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000037}
38
telsoa014fcda012018-03-09 14:13:49 +000039} //namespace armnn
40