blob: 42737e2af649d00ac3682fae415d046be075c496 [file] [log] [blame]
//
// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefFullyConnectedWorkload.hpp"
7
8#include "FullyConnected.hpp"
9#include "RefWorkloadUtils.hpp"
10
11#include "Profiling.hpp"
12
13namespace armnn
14{
Finn Williamsd0420cb2022-05-18 13:24:14 +010015
16unsigned int GetNumActivations(const TensorInfo& inputInfo)
17{
18 unsigned int numActivations = 1; // Total number of activations in the input.
19 for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
20 {
21 numActivations *= inputInfo.GetShape()[i];
22 }
23 return numActivations;
24}
25
26
Francis Murtagh43aec582019-05-27 12:14:10 +010027RefFullyConnectedWorkload::RefFullyConnectedWorkload(
28 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
Finn Williams73c547d2022-02-15 20:47:34 +000029 : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
Finn Williamsd0420cb2022-05-18 13:24:14 +010030 , m_InputShape(info.m_InputTensorInfos[0].GetShape())
31 , m_WeightShape(info.m_InputTensorInfos[1].GetShape())
32 , m_OutputShape(info.m_OutputTensorInfos[0].GetShape())
33 , m_NumActivations(GetNumActivations(info.m_InputTensorInfos[0]))
Francis Murtagh43aec582019-05-27 12:14:10 +010034{
Francis Murtagh43aec582019-05-27 12:14:10 +010035}
36
Francis Murtagh43aec582019-05-27 12:14:10 +010037void RefFullyConnectedWorkload::Execute() const
38{
Finn Williamsb8181f72021-04-07 10:23:21 +010039 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
40}
41
Matthew Sloyan2d213a72022-06-30 17:13:04 +010042void RefFullyConnectedWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010043{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010044 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
45 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010046}
47
48void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
49{
Mike Kelly7cbe7812023-07-25 17:37:33 +010050 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefFullyConnectedWorkload_Execute");
Francis Murtagh43aec582019-05-27 12:14:10 +010051
Finn Williamsb8181f72021-04-07 10:23:21 +010052 std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map());
53 std::unique_ptr<Encoder<float>> OutputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map());
54
Finn Williamsd0420cb2022-05-18 13:24:14 +010055 std::unique_ptr<Decoder<float>> weightsDecoder = MakeDecoder<float>(GetTensorInfo(inputs[1]), inputs[1]->Map());
56 std::unique_ptr<Decoder<float>> biasDecoder;
57
Matthew Sloyan81beae32021-07-13 19:46:11 +010058 if (m_Data.m_Parameters.m_BiasEnabled)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000059 {
Finn Williamsd0420cb2022-05-18 13:24:14 +010060 biasDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), inputs[2]->Map());
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000061 }
Matthew Benthamc394a6d2019-06-24 12:51:25 +010062
Francis Murtagh43aec582019-05-27 12:14:10 +010063 FullyConnected(m_InputShape,
Finn Williamsb8181f72021-04-07 10:23:21 +010064 *inputDecoder,
Francis Murtagh43aec582019-05-27 12:14:10 +010065 m_OutputShape,
Finn Williamsb8181f72021-04-07 10:23:21 +010066 *OutputEncoder,
Finn Williamsb9dcfe62020-09-17 15:58:31 +010067 m_WeightShape,
Finn Williamsd0420cb2022-05-18 13:24:14 +010068 *weightsDecoder,
69 biasDecoder.get(),
Francis Murtagh43aec582019-05-27 12:14:10 +010070 m_Data.m_Parameters.m_BiasEnabled,
71 m_NumActivations,
72 m_Data.m_Parameters.m_TransposeWeightMatrix);
73}
74
75} //namespace armnn