blob: dc7030ef81a10c2a4e4446c5a9b4042c8bb37aeb [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefFullyConnectedWorkload.hpp"
7
8#include "FullyConnected.hpp"
9#include "RefWorkloadUtils.hpp"
10
11#include "Profiling.hpp"
12
13namespace armnn
14{
15RefFullyConnectedWorkload::RefFullyConnectedWorkload(
16 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
17 : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
18 m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
19{
20 const TensorInfo& rWeightInfo = GetTensorInfo(m_Weight.get());
21 m_WeightShape = rWeightInfo.GetShape();
22 m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
23
24 if (descriptor.m_Parameters.m_BiasEnabled)
25 {
26 m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
27 const TensorInfo& biasInfo = GetTensorInfo(m_Bias.get());
28 m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
29 }
30}
31
32void RefFullyConnectedWorkload::PostAllocationConfigure()
33{
34 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
35 BOOST_ASSERT(inputInfo.GetNumDimensions() > 1);
36 m_InputShape = inputInfo.GetShape();
37 m_InputDecoder = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
38
39 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
40 m_OutputShape = outputInfo.GetShape();
41 m_OutputEncoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
42
43 m_NumActivations = 1; // Total number of activations in the input.
44 for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
45 {
46 m_NumActivations *= inputInfo.GetShape()[i];
47 }
48}
49
50void RefFullyConnectedWorkload::Execute() const
51{
52 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
53
54 FullyConnected(m_InputShape,
55 *m_InputDecoder,
56 m_OutputShape,
57 *m_OutputEncoder,
58 *m_WeightDecoder,
59 *m_BiasDecoder,
60 m_Data.m_Parameters.m_BiasEnabled,
61 m_NumActivations,
62 m_Data.m_Parameters.m_TransposeWeightMatrix);
63}
64
65} //namespace armnn