blob: 02d9b060ef6e8a897b99170b40398748e7173073 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "FullyConnected.hpp"
7
Francis Murtagh43aec582019-05-27 12:14:10 +01008#include "RefWorkloadUtils.hpp"
9
telsoa014fcda012018-03-09 14:13:49 +000010#include <boost/assert.hpp>
11
12namespace armnn
13{
14
Francis Murtagh43aec582019-05-27 12:14:10 +010015void FullyConnected(const TensorShape& rInputShape,
16 Decoder<float>& rInputDecoder,
17 const TensorShape& rOutputShape,
18 Encoder<float>& rOutputEncoder,
19 Decoder<float>& rWeightDecoder,
20 Decoder<float>& rBiasDecoder,
21 const bool biasEnabled,
22 const unsigned int K,
23 const bool transposeWeights)
telsoa014fcda012018-03-09 14:13:49 +000024{
Francis Murtagh43aec582019-05-27 12:14:10 +010025 // Perform FullyConnected implementation
26 unsigned int outputSize = rOutputShape[1];
telsoa014fcda012018-03-09 14:13:49 +000027
Francis Murtagh43aec582019-05-27 12:14:10 +010028 for (unsigned int n = 0; n < rInputShape[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000029 {
Francis Murtagh43aec582019-05-27 12:14:10 +010030 for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
telsoa014fcda012018-03-09 14:13:49 +000031 {
32 float outval = 0.f;
33
34 for (unsigned int channelInput = 0; channelInput < K; channelInput++)
35 {
36 float weight;
37 if (transposeWeights)
38 {
Francis Murtagh43aec582019-05-27 12:14:10 +010039 rWeightDecoder[channelOutput * K + channelInput];
40 weight = rWeightDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000041 }
42 else
43 {
Francis Murtagh43aec582019-05-27 12:14:10 +010044 rWeightDecoder[channelInput * outputSize + channelOutput];
45 weight = rWeightDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000046 }
47
Francis Murtagh43aec582019-05-27 12:14:10 +010048 rInputDecoder[n * K + channelInput];
49 outval += weight * rInputDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000050 }
51
Francis Murtagh43aec582019-05-27 12:14:10 +010052 if (biasEnabled)
telsoa014fcda012018-03-09 14:13:49 +000053 {
Francis Murtagh43aec582019-05-27 12:14:10 +010054 rBiasDecoder[channelOutput];
55 outval += rBiasDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000056 }
57
Francis Murtagh43aec582019-05-27 12:14:10 +010058 rOutputEncoder[n * outputSize + channelOutput];
59 rOutputEncoder.Set(outval);
telsoa014fcda012018-03-09 14:13:49 +000060 }
61 }
62}
63
64} //namespace armnn