blob: 19c01b89877129d4f3f0faca02cad5776f0862b9 [file] [log] [blame]
//
// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "FullyConnected.hpp"
7
Francis Murtagh43aec582019-05-27 12:14:10 +01008#include "RefWorkloadUtils.hpp"
9
telsoa014fcda012018-03-09 14:13:49 +000010namespace armnn
11{
12
Francis Murtagh43aec582019-05-27 12:14:10 +010013void FullyConnected(const TensorShape& rInputShape,
14 Decoder<float>& rInputDecoder,
15 const TensorShape& rOutputShape,
16 Encoder<float>& rOutputEncoder,
Finn Williamsb9dcfe62020-09-17 15:58:31 +010017 const TensorShape& rWeightsShape,
Francis Murtagh43aec582019-05-27 12:14:10 +010018 Decoder<float>& rWeightDecoder,
Matthew Bentham18bf43d2021-07-07 09:08:48 +010019 Decoder<float>* pBiasDecoder,
Francis Murtagh43aec582019-05-27 12:14:10 +010020 const bool biasEnabled,
21 const unsigned int K,
22 const bool transposeWeights)
telsoa014fcda012018-03-09 14:13:49 +000023{
Francis Murtagh43aec582019-05-27 12:14:10 +010024 // Perform FullyConnected implementation
25 unsigned int outputSize = rOutputShape[1];
telsoa014fcda012018-03-09 14:13:49 +000026
Finn Williamsea8ce702020-09-29 19:54:00 +010027 const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape);
28 const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape);
29
30 const TensorShape biasShape{outputSize};
Matthew Bentham18bf43d2021-07-07 09:08:48 +010031
Matthew Bentham18bf43d2021-07-07 09:08:48 +010032 const std::vector<float> decodedBiases = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
Finn Williamsb9dcfe62020-09-17 15:58:31 +010033
34
Francis Murtagh43aec582019-05-27 12:14:10 +010035 for (unsigned int n = 0; n < rInputShape[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000036 {
Francis Murtagh43aec582019-05-27 12:14:10 +010037 for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
telsoa014fcda012018-03-09 14:13:49 +000038 {
39 float outval = 0.f;
40
41 for (unsigned int channelInput = 0; channelInput < K; channelInput++)
42 {
43 float weight;
44 if (transposeWeights)
45 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010046 weight = decodedWeights[channelOutput * K + channelInput];
telsoa014fcda012018-03-09 14:13:49 +000047 }
48 else
49 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010050 weight = decodedWeights[channelInput * outputSize + channelOutput];
telsoa014fcda012018-03-09 14:13:49 +000051 }
52
Finn Williamsb9dcfe62020-09-17 15:58:31 +010053 outval += weight * decodedInputs[n * K + channelInput];
telsoa014fcda012018-03-09 14:13:49 +000054 }
55
Francis Murtagh43aec582019-05-27 12:14:10 +010056 if (biasEnabled)
telsoa014fcda012018-03-09 14:13:49 +000057 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010058 outval += decodedBiases[channelOutput];
telsoa014fcda012018-03-09 14:13:49 +000059 }
60
Francis Murtagh43aec582019-05-27 12:14:10 +010061 rOutputEncoder[n * outputSize + channelOutput];
62 rOutputEncoder.Set(outval);
telsoa014fcda012018-03-09 14:13:49 +000063 }
64 }
65}
66
67} //namespace armnn