blob: 49e105f206eecf07a15f132dd373bd4df4cfe423 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Francis Murtagh43aec582019-05-27 12:14:10 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefFullyConnectedWorkload.hpp"
7
8#include "FullyConnected.hpp"
9#include "RefWorkloadUtils.hpp"
10
11#include "Profiling.hpp"
12
13namespace armnn
14{
15RefFullyConnectedWorkload::RefFullyConnectedWorkload(
16 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000017 : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
Francis Murtagh43aec582019-05-27 12:14:10 +010018{
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000019 if (descriptor.m_Parameters.m_ConstantWeights)
Francis Murtagh43aec582019-05-27 12:14:10 +010020 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000021 m_Weight = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight));
22 const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
23 m_WeightShape = rWeightInfo.GetShape();
24 m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
25
26 if (descriptor.m_Parameters.m_BiasEnabled)
27 {
28 m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
29 const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
30 m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
31 }
Francis Murtagh43aec582019-05-27 12:14:10 +010032 }
33}
34
35void RefFullyConnectedWorkload::PostAllocationConfigure()
36{
37 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010038 ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
Francis Murtagh43aec582019-05-27 12:14:10 +010039 m_InputShape = inputInfo.GetShape();
Matthew Benthamc394a6d2019-06-24 12:51:25 +010040 m_InputDecoder = MakeDecoder<float>(inputInfo);
Francis Murtagh43aec582019-05-27 12:14:10 +010041
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000042 if (!m_Data.m_Parameters.m_ConstantWeights)
43 {
44 const TensorInfo& rWeightInfo = GetTensorInfo(m_Data.m_Inputs[1]);
45 ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
46 m_WeightShape = rWeightInfo.GetShape();
47 m_WeightDecoder = MakeDecoder<float>(rWeightInfo);
48
49 if (m_Data.m_Parameters.m_BiasEnabled)
50 {
51 const TensorInfo& biasInfo = GetTensorInfo(m_Data.m_Inputs[2]);
52 m_BiasDecoder = MakeDecoder<float>(biasInfo);
53 }
54 }
55
Francis Murtagh43aec582019-05-27 12:14:10 +010056 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
57 m_OutputShape = outputInfo.GetShape();
Matthew Benthamc394a6d2019-06-24 12:51:25 +010058 m_OutputEncoder = MakeEncoder<float>(outputInfo);
Francis Murtagh43aec582019-05-27 12:14:10 +010059
60 m_NumActivations = 1; // Total number of activations in the input.
61 for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
62 {
63 m_NumActivations *= inputInfo.GetShape()[i];
64 }
65}
66
67void RefFullyConnectedWorkload::Execute() const
68{
69 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
70
Matthew Benthamc394a6d2019-06-24 12:51:25 +010071 m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000072 if (!m_Data.m_Parameters.m_ConstantWeights)
73 {
74 m_WeightDecoder->Reset(m_Data.m_Inputs[1]->Map());
75 if (m_Data.m_Parameters.m_BiasEnabled)
76 {
77 m_BiasDecoder->Reset(m_Data.m_Inputs[2]->Map());
78 }
79 }
Matthew Benthamc394a6d2019-06-24 12:51:25 +010080 m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
81
Francis Murtagh43aec582019-05-27 12:14:10 +010082 FullyConnected(m_InputShape,
83 *m_InputDecoder,
84 m_OutputShape,
85 *m_OutputEncoder,
Finn Williamsb9dcfe62020-09-17 15:58:31 +010086 m_WeightShape,
Francis Murtagh43aec582019-05-27 12:14:10 +010087 *m_WeightDecoder,
88 *m_BiasDecoder,
89 m_Data.m_Parameters.m_BiasEnabled,
90 m_NumActivations,
91 m_Data.m_Parameters.m_TransposeWeightMatrix);
92}
93
94} //namespace armnn