blob: 9ec9ea6c6c30bd4782231df88216309d5fbbef0a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "FullyConnected.hpp"
7
Francis Murtagh43aec582019-05-27 12:14:10 +01008#include "RefWorkloadUtils.hpp"
9
telsoa014fcda012018-03-09 14:13:49 +000010namespace armnn
11{
12
Francis Murtagh43aec582019-05-27 12:14:10 +010013void FullyConnected(const TensorShape& rInputShape,
14 Decoder<float>& rInputDecoder,
15 const TensorShape& rOutputShape,
16 Encoder<float>& rOutputEncoder,
Finn Williamsb9dcfe62020-09-17 15:58:31 +010017 const TensorShape& rWeightsShape,
Francis Murtagh43aec582019-05-27 12:14:10 +010018 Decoder<float>& rWeightDecoder,
19 Decoder<float>& rBiasDecoder,
20 const bool biasEnabled,
21 const unsigned int K,
22 const bool transposeWeights)
telsoa014fcda012018-03-09 14:13:49 +000023{
Francis Murtagh43aec582019-05-27 12:14:10 +010024 // Perform FullyConnected implementation
25 unsigned int outputSize = rOutputShape[1];
telsoa014fcda012018-03-09 14:13:49 +000026
Finn Williamsea8ce702020-09-29 19:54:00 +010027 const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape);
28 const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape);
29
30 const TensorShape biasShape{outputSize};
31 const std::vector<float> decodedBiases = biasEnabled ? rBiasDecoder.DecodeTensor(biasShape) : std::vector<float>();
Finn Williamsb9dcfe62020-09-17 15:58:31 +010032
33
Francis Murtagh43aec582019-05-27 12:14:10 +010034 for (unsigned int n = 0; n < rInputShape[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000035 {
Francis Murtagh43aec582019-05-27 12:14:10 +010036 for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
telsoa014fcda012018-03-09 14:13:49 +000037 {
38 float outval = 0.f;
39
40 for (unsigned int channelInput = 0; channelInput < K; channelInput++)
41 {
42 float weight;
43 if (transposeWeights)
44 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010045 weight = decodedWeights[channelOutput * K + channelInput];
telsoa014fcda012018-03-09 14:13:49 +000046 }
47 else
48 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010049 weight = decodedWeights[channelInput * outputSize + channelOutput];
telsoa014fcda012018-03-09 14:13:49 +000050 }
51
Finn Williamsb9dcfe62020-09-17 15:58:31 +010052 outval += weight * decodedInputs[n * K + channelInput];
telsoa014fcda012018-03-09 14:13:49 +000053 }
54
Francis Murtagh43aec582019-05-27 12:14:10 +010055 if (biasEnabled)
telsoa014fcda012018-03-09 14:13:49 +000056 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010057 outval += decodedBiases[channelOutput];
telsoa014fcda012018-03-09 14:13:49 +000058 }
59
Francis Murtagh43aec582019-05-27 12:14:10 +010060 rOutputEncoder[n * outputSize + channelOutput];
61 rOutputEncoder.Set(outval);
telsoa014fcda012018-03-09 14:13:49 +000062 }
63 }
64}
65
66} //namespace armnn