blob: 47968f4d88cc360a61b6205e6f6b4b9ad0f1ca3d [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "FullyConnected.hpp"
7
Matthew Bentham18bf43d2021-07-07 09:08:48 +01008#include <armnn/utility/Assert.hpp>
9
Francis Murtagh43aec582019-05-27 12:14:10 +010010#include "RefWorkloadUtils.hpp"
11
telsoa014fcda012018-03-09 14:13:49 +000012namespace armnn
13{
14
Francis Murtagh43aec582019-05-27 12:14:10 +010015void FullyConnected(const TensorShape& rInputShape,
16 Decoder<float>& rInputDecoder,
17 const TensorShape& rOutputShape,
18 Encoder<float>& rOutputEncoder,
Finn Williamsb9dcfe62020-09-17 15:58:31 +010019 const TensorShape& rWeightsShape,
Francis Murtagh43aec582019-05-27 12:14:10 +010020 Decoder<float>& rWeightDecoder,
Matthew Bentham18bf43d2021-07-07 09:08:48 +010021 Decoder<float>* pBiasDecoder,
Francis Murtagh43aec582019-05-27 12:14:10 +010022 const bool biasEnabled,
23 const unsigned int K,
24 const bool transposeWeights)
telsoa014fcda012018-03-09 14:13:49 +000025{
Francis Murtagh43aec582019-05-27 12:14:10 +010026 // Perform FullyConnected implementation
27 unsigned int outputSize = rOutputShape[1];
telsoa014fcda012018-03-09 14:13:49 +000028
Finn Williamsea8ce702020-09-29 19:54:00 +010029 const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape);
30 const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape);
31
32 const TensorShape biasShape{outputSize};
Matthew Bentham18bf43d2021-07-07 09:08:48 +010033
34 ARMNN_ASSERT(!biasEnabled || pBiasDecoder != nullptr);
35 const std::vector<float> decodedBiases = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
Finn Williamsb9dcfe62020-09-17 15:58:31 +010036
37
Francis Murtagh43aec582019-05-27 12:14:10 +010038 for (unsigned int n = 0; n < rInputShape[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000039 {
Francis Murtagh43aec582019-05-27 12:14:10 +010040 for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
telsoa014fcda012018-03-09 14:13:49 +000041 {
42 float outval = 0.f;
43
44 for (unsigned int channelInput = 0; channelInput < K; channelInput++)
45 {
46 float weight;
47 if (transposeWeights)
48 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010049 weight = decodedWeights[channelOutput * K + channelInput];
telsoa014fcda012018-03-09 14:13:49 +000050 }
51 else
52 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010053 weight = decodedWeights[channelInput * outputSize + channelOutput];
telsoa014fcda012018-03-09 14:13:49 +000054 }
55
Finn Williamsb9dcfe62020-09-17 15:58:31 +010056 outval += weight * decodedInputs[n * K + channelInput];
telsoa014fcda012018-03-09 14:13:49 +000057 }
58
Francis Murtagh43aec582019-05-27 12:14:10 +010059 if (biasEnabled)
telsoa014fcda012018-03-09 14:13:49 +000060 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010061 outval += decodedBiases[channelOutput];
telsoa014fcda012018-03-09 14:13:49 +000062 }
63
Francis Murtagh43aec582019-05-27 12:14:10 +010064 rOutputEncoder[n * outputSize + channelOutput];
65 rOutputEncoder.Set(outval);
telsoa014fcda012018-03-09 14:13:49 +000066 }
67 }
68}
69
70} //namespace armnn