blob: 087fc9da68f186e4ca46342ee0d74e7d7fd32b8a [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Francis Murtagh43aec582019-05-27 12:14:10 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefFullyConnectedWorkload.hpp"
7
8#include "FullyConnected.hpp"
9#include "RefWorkloadUtils.hpp"
10
11#include "Profiling.hpp"
12
13namespace armnn
14{
Finn Williamsd0420cb2022-05-18 13:24:14 +010015
16unsigned int GetNumActivations(const TensorInfo& inputInfo)
17{
18 unsigned int numActivations = 1; // Total number of activations in the input.
19 for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
20 {
21 numActivations *= inputInfo.GetShape()[i];
22 }
23 return numActivations;
24}
25
26
Francis Murtagh43aec582019-05-27 12:14:10 +010027RefFullyConnectedWorkload::RefFullyConnectedWorkload(
28 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
Finn Williams73c547d2022-02-15 20:47:34 +000029 : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
Finn Williamsd0420cb2022-05-18 13:24:14 +010030 , m_InputShape(info.m_InputTensorInfos[0].GetShape())
31 , m_WeightShape(info.m_InputTensorInfos[1].GetShape())
32 , m_OutputShape(info.m_OutputTensorInfos[0].GetShape())
33 , m_NumActivations(GetNumActivations(info.m_InputTensorInfos[0]))
Francis Murtagh43aec582019-05-27 12:14:10 +010034{
Francis Murtagh43aec582019-05-27 12:14:10 +010035}
36
Francis Murtagh43aec582019-05-27 12:14:10 +010037void RefFullyConnectedWorkload::Execute() const
38{
Finn Williamsb8181f72021-04-07 10:23:21 +010039 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
40}
41
42void RefFullyConnectedWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
43{
Finn Williamsb8181f72021-04-07 10:23:21 +010044 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
45}
46
47void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
48{
Francis Murtagh43aec582019-05-27 12:14:10 +010049 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
50
Finn Williamsb8181f72021-04-07 10:23:21 +010051 std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map());
52 std::unique_ptr<Encoder<float>> OutputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map());
53
Finn Williamsd0420cb2022-05-18 13:24:14 +010054 std::unique_ptr<Decoder<float>> weightsDecoder = MakeDecoder<float>(GetTensorInfo(inputs[1]), inputs[1]->Map());
55 std::unique_ptr<Decoder<float>> biasDecoder;
56
Matthew Sloyan81beae32021-07-13 19:46:11 +010057 if (m_Data.m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000058 {
Finn Williamsd0420cb2022-05-18 13:24:14 +010059 biasDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), inputs[2]->Map());
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000060 }
Matthew Benthamc394a6d2019-06-24 12:51:25 +010061
Francis Murtagh43aec582019-05-27 12:14:10 +010062 FullyConnected(m_InputShape,
Finn Williamsb8181f72021-04-07 10:23:21 +010063 *inputDecoder,
Francis Murtagh43aec582019-05-27 12:14:10 +010064 m_OutputShape,
Finn Williamsb8181f72021-04-07 10:23:21 +010065 *OutputEncoder,
Finn Williamsb9dcfe62020-09-17 15:58:31 +010066 m_WeightShape,
Finn Williamsd0420cb2022-05-18 13:24:14 +010067 *weightsDecoder,
68 biasDecoder.get(),
Francis Murtagh43aec582019-05-27 12:14:10 +010069 m_Data.m_Parameters.m_BiasEnabled,
70 m_NumActivations,
71 m_Data.m_Parameters.m_TransposeWeightMatrix);
72}
73
74} //namespace armnn