blob: 338c7eb1f6a97a6eef82fdb438c1025c67744372 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
kevmay01e448be32018-09-26 10:21:55 +01006#include "NeonFullyConnectedWorkload.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01007
Matthew Benthamd80a7122019-01-08 17:52:37 +00008#include "NeonWorkloadUtils.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <aclCommon/ArmComputeUtils.hpp>
11#include <backendsCommon/CpuTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matthew Benthamd80a7122019-01-08 17:52:37 +000013#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
14
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17using namespace armcomputetensorutils;
18
telsoa01c577f2c2018-08-31 09:22:23 +010019arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const TensorInfo& weights,
22 const TensorInfo& biases,
23 const FullyConnectedDescriptor& descriptor)
24{
25 const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
26 const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
27 const arm_compute::TensorInfo aclWeights = BuildArmComputeTensorInfo(weights);
28
29 arm_compute::TensorInfo aclBiases;
30 arm_compute::TensorInfo *optionalAclBiases = nullptr;
31 if (descriptor.m_BiasEnabled)
32 {
33 aclBiases = BuildArmComputeTensorInfo(biases);
34 optionalAclBiases = &aclBiases;
35 }
36
37 const arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo =
38 ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(descriptor);
39
40
41 return arm_compute::NEFullyConnectedLayer::validate(&aclInput,
42 &aclWeights,
43 optionalAclBiases,
44 &aclOutput,
45 fullyConnectedLayerInfo);
46}
47
kevmay01e448be32018-09-26 10:21:55 +010048NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010049 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
kevmay01e448be32018-09-26 10:21:55 +010050 : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000051{
kevmay01e448be32018-09-26 10:21:55 +010052 m_Data.ValidateInputsOutputs("NeonFullyConnectedWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000053
Derek Lambertic81855f2019-06-13 17:34:19 +010054 arm_compute::ITensor& input = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
55 arm_compute::ITensor& output = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000056
telsoa01c577f2c2018-08-31 09:22:23 +010057 m_WeightsTensor = std::make_unique<arm_compute::Tensor>();
58 BuildArmComputeTensor(*m_WeightsTensor, m_Data.m_Weight->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000059
telsoa014fcda012018-03-09 14:13:49 +000060 if (m_Data.m_Parameters.m_BiasEnabled)
61 {
telsoa01c577f2c2018-08-31 09:22:23 +010062 m_BiasesTensor = std::make_unique<arm_compute::Tensor>();
63 BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
telsoa014fcda012018-03-09 14:13:49 +000064 }
65
66 // Construct
telsoa01c577f2c2018-08-31 09:22:23 +010067 arm_compute::FullyConnectedLayerInfo fc_info;
68 fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
Matthew Benthamd80a7122019-01-08 17:52:37 +000069
70 auto layer = std::make_unique<arm_compute::NEFullyConnectedLayer>(memoryManager);
71 layer->configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
72 m_FullyConnectedLayer.reset(layer.release());
telsoa014fcda012018-03-09 14:13:49 +000073
74 // Allocate
Derek Lambertif90c56d2020-01-10 17:14:08 +000075 if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QAsymmU8)
kevmay01e448be32018-09-26 10:21:55 +010076 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010077 InitializeArmComputeTensorData(*m_WeightsTensor, m_Data.m_Weight);
kevmay01e448be32018-09-26 10:21:55 +010078 }
79 else
80 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010081 InitializeArmComputeTensorData(*m_WeightsTensor, m_Data.m_Weight);
kevmay01e448be32018-09-26 10:21:55 +010082 }
telsoa014fcda012018-03-09 14:13:49 +000083
telsoa01c577f2c2018-08-31 09:22:23 +010084 if (m_BiasesTensor)
telsoa014fcda012018-03-09 14:13:49 +000085 {
kevmay01e448be32018-09-26 10:21:55 +010086 if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
87 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010088 InitializeArmComputeTensorData(*m_BiasesTensor, m_Data.m_Bias);
kevmay01e448be32018-09-26 10:21:55 +010089 }
90 else
91 {
Nattapat Chaimanowong177d8d22018-10-16 13:21:27 +010092 InitializeArmComputeTensorData(*m_BiasesTensor, m_Data.m_Bias);
kevmay01e448be32018-09-26 10:21:55 +010093 }
telsoa014fcda012018-03-09 14:13:49 +000094 }
telsoa01c577f2c2018-08-31 09:22:23 +010095
96 // Force Compute Library to perform the necessary copying and reshaping, after which
97 // delete all the input tensors that will no longer be needed
Matthew Benthamd80a7122019-01-08 17:52:37 +000098 m_FullyConnectedLayer->prepare();
telsoa01c577f2c2018-08-31 09:22:23 +010099 FreeUnusedTensors();
telsoa014fcda012018-03-09 14:13:49 +0000100}
101
kevmay01e448be32018-09-26 10:21:55 +0100102void NeonFullyConnectedWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +0000103{
kevmay01e448be32018-09-26 10:21:55 +0100104 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonFullyConnectedWorkload_Execute");
Matthew Benthamd80a7122019-01-08 17:52:37 +0000105 m_FullyConnectedLayer->run();
telsoa014fcda012018-03-09 14:13:49 +0000106}
107
kevmay01e448be32018-09-26 10:21:55 +0100108void NeonFullyConnectedWorkload::FreeUnusedTensors()
telsoa01c577f2c2018-08-31 09:22:23 +0100109{
110 FreeTensorIfUnused(m_WeightsTensor);
111 FreeTensorIfUnused(m_BiasesTensor);
112}
113
telsoa014fcda012018-03-09 14:13:49 +0000114} //namespace armnn