blob: e808c60c0c0f2d05b58e1011fac522d111cb1637 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
kevmay01e448be32018-09-26 10:21:55 +01006#include "NeonFullyConnectedWorkload.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01007
Matthew Benthamd80a7122019-01-08 17:52:37 +00008#include "NeonWorkloadUtils.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <aclCommon/ArmComputeUtils.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010011#include <armnn/utility/PolymorphicDowncast.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <backendsCommon/CpuTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matthew Benthamd80a7122019-01-08 17:52:37 +000014#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
15
telsoa014fcda012018-03-09 14:13:49 +000016namespace armnn
17{
18using namespace armcomputetensorutils;
19
telsoa01c577f2c2018-08-31 09:22:23 +010020arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
21 const TensorInfo& output,
22 const TensorInfo& weights,
23 const TensorInfo& biases,
24 const FullyConnectedDescriptor& descriptor)
25{
26 const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
27 const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
28 const arm_compute::TensorInfo aclWeights = BuildArmComputeTensorInfo(weights);
29
30 arm_compute::TensorInfo aclBiases;
31 arm_compute::TensorInfo *optionalAclBiases = nullptr;
32 if (descriptor.m_BiasEnabled)
33 {
34 aclBiases = BuildArmComputeTensorInfo(biases);
35 optionalAclBiases = &aclBiases;
36 }
37
38 const arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo =
39 ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(descriptor);
40
41
42 return arm_compute::NEFullyConnectedLayer::validate(&aclInput,
43 &aclWeights,
44 optionalAclBiases,
45 &aclOutput,
46 fullyConnectedLayerInfo);
47}
48
kevmay01e448be32018-09-26 10:21:55 +010049NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010050 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
kevmay01e448be32018-09-26 10:21:55 +010051 : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000052{
kevmay01e448be32018-09-26 10:21:55 +010053 m_Data.ValidateInputsOutputs("NeonFullyConnectedWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000054
Jan Eilersbb446e52020-04-02 13:56:54 +010055 arm_compute::ITensor& input = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
56 arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000057
telsoa01c577f2c2018-08-31 09:22:23 +010058 m_WeightsTensor = std::make_unique<arm_compute::Tensor>();
59 BuildArmComputeTensor(*m_WeightsTensor, m_Data.m_Weight->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000060
telsoa014fcda012018-03-09 14:13:49 +000061 if (m_Data.m_Parameters.m_BiasEnabled)
62 {
telsoa01c577f2c2018-08-31 09:22:23 +010063 m_BiasesTensor = std::make_unique<arm_compute::Tensor>();
64 BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000065 }
66
67 // Construct
telsoa01c577f2c2018-08-31 09:22:23 +010068 arm_compute::FullyConnectedLayerInfo fc_info;
69 fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
Matthew Benthamd80a7122019-01-08 17:52:37 +000070
71 auto layer = std::make_unique<arm_compute::NEFullyConnectedLayer>(memoryManager);
72 layer->configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
73 m_FullyConnectedLayer.reset(layer.release());
telsoa014fcda012018-03-09 14:13:49 +000074
75 // Allocate
Derek Lambertif90c56d2020-01-10 17:14:08 +000076 if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QAsymmU8)
kevmay01e448be32018-09-26 10:21:55 +010077 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010078 InitializeArmComputeTensorData(*m_WeightsTensor, m_Data.m_Weight);
kevmay01e448be32018-09-26 10:21:55 +010079 }
80 else
81 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010082 InitializeArmComputeTensorData(*m_WeightsTensor, m_Data.m_Weight);
kevmay01e448be32018-09-26 10:21:55 +010083 }
telsoa014fcda012018-03-09 14:13:49 +000084
telsoa01c577f2c2018-08-31 09:22:23 +010085 if (m_BiasesTensor)
telsoa014fcda012018-03-09 14:13:49 +000086 {
kevmay01e448be32018-09-26 10:21:55 +010087 if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
88 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010089 InitializeArmComputeTensorData(*m_BiasesTensor, m_Data.m_Bias);
kevmay01e448be32018-09-26 10:21:55 +010090 }
91 else
92 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010093 InitializeArmComputeTensorData(*m_BiasesTensor, m_Data.m_Bias);
kevmay01e448be32018-09-26 10:21:55 +010094 }
telsoa014fcda012018-03-09 14:13:49 +000095 }
telsoa01c577f2c2018-08-31 09:22:23 +010096
97 // Force Compute Library to perform the necessary copying and reshaping, after which
98 // delete all the input tensors that will no longer be needed
Matthew Benthamd80a7122019-01-08 17:52:37 +000099 m_FullyConnectedLayer->prepare();
telsoa01c577f2c2018-08-31 09:22:23 +0100100 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +0000101}
102
kevmay01e448be32018-09-26 10:21:55 +0100103void NeonFullyConnectedWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +0000104{
kevmay01e448be32018-09-26 10:21:55 +0100105 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonFullyConnectedWorkload_Execute");
Matthew Benthamd80a7122019-01-08 17:52:37 +0000106 m_FullyConnectedLayer->run();
telsoa014fcda012018-03-09 14:13:49 +0000107}
108
kevmay01e448be32018-09-26 10:21:55 +0100109void NeonFullyConnectedWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100110{
111 FreeTensorIfUnused(m_WeightsTensor);
112 FreeTensorIfUnused(m_BiasesTensor);
113}
114
telsoa014fcda012018-03-09 14:13:49 +0000115} //namespace armnn