blob: f8c3548905c1d8a0261fe519ddeeae15ea41575d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefFullyConnectedWorkload.hpp"
7
8#include "FullyConnected.hpp"
9#include "RefWorkloadUtils.hpp"
10
11#include "Profiling.hpp"
12
13namespace armnn
14{
15RefFullyConnectedWorkload::RefFullyConnectedWorkload(
16 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
17 : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
18 m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
19{
Matthew Bentham4cefc412019-06-18 16:14:34 +010020 const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
Francis Murtagh43aec582019-05-27 12:14:10 +010021 m_WeightShape = rWeightInfo.GetShape();
22 m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
23
24 if (descriptor.m_Parameters.m_BiasEnabled)
25 {
26 m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
Matthew Bentham4cefc412019-06-18 16:14:34 +010027 const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
Francis Murtagh43aec582019-05-27 12:14:10 +010028 m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
29 }
30}
31
32void RefFullyConnectedWorkload::PostAllocationConfigure()
33{
34 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010035 ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
Francis Murtagh43aec582019-05-27 12:14:10 +010036 m_InputShape = inputInfo.GetShape();
Matthew Benthamc394a6d2019-06-24 12:51:25 +010037 m_InputDecoder = MakeDecoder<float>(inputInfo);
Francis Murtagh43aec582019-05-27 12:14:10 +010038
39 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
40 m_OutputShape = outputInfo.GetShape();
Matthew Benthamc394a6d2019-06-24 12:51:25 +010041 m_OutputEncoder = MakeEncoder<float>(outputInfo);
Francis Murtagh43aec582019-05-27 12:14:10 +010042
43 m_NumActivations = 1; // Total number of activations in the input.
44 for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
45 {
46 m_NumActivations *= inputInfo.GetShape()[i];
47 }
48}
49
50void RefFullyConnectedWorkload::Execute() const
51{
52 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
53
Matthew Benthamc394a6d2019-06-24 12:51:25 +010054 m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
55 m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
56
Francis Murtagh43aec582019-05-27 12:14:10 +010057 FullyConnected(m_InputShape,
58 *m_InputDecoder,
59 m_OutputShape,
60 *m_OutputEncoder,
61 *m_WeightDecoder,
62 *m_BiasDecoder,
63 m_Data.m_Parameters.m_BiasEnabled,
64 m_NumActivations,
65 m_Data.m_Parameters.m_TransposeWeightMatrix);
66}
67
68} //namespace armnn