blob: 8016c1b628570238938868bdbd5fc3951da144a0 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "FullyConnected.hpp"
7
Francis Murtagh43aec582019-05-27 12:14:10 +01008#include "RefWorkloadUtils.hpp"
9
telsoa014fcda012018-03-09 14:13:49 +000010namespace armnn
11{
12
Francis Murtagh43aec582019-05-27 12:14:10 +010013void FullyConnected(const TensorShape& rInputShape,
14 Decoder<float>& rInputDecoder,
15 const TensorShape& rOutputShape,
16 Encoder<float>& rOutputEncoder,
17 Decoder<float>& rWeightDecoder,
18 Decoder<float>& rBiasDecoder,
19 const bool biasEnabled,
20 const unsigned int K,
21 const bool transposeWeights)
telsoa014fcda012018-03-09 14:13:49 +000022{
Francis Murtagh43aec582019-05-27 12:14:10 +010023 // Perform FullyConnected implementation
24 unsigned int outputSize = rOutputShape[1];
telsoa014fcda012018-03-09 14:13:49 +000025
Francis Murtagh43aec582019-05-27 12:14:10 +010026 for (unsigned int n = 0; n < rInputShape[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000027 {
Francis Murtagh43aec582019-05-27 12:14:10 +010028 for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
telsoa014fcda012018-03-09 14:13:49 +000029 {
30 float outval = 0.f;
31
32 for (unsigned int channelInput = 0; channelInput < K; channelInput++)
33 {
34 float weight;
35 if (transposeWeights)
36 {
Francis Murtagh43aec582019-05-27 12:14:10 +010037 rWeightDecoder[channelOutput * K + channelInput];
38 weight = rWeightDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000039 }
40 else
41 {
Francis Murtagh43aec582019-05-27 12:14:10 +010042 rWeightDecoder[channelInput * outputSize + channelOutput];
43 weight = rWeightDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000044 }
45
Francis Murtagh43aec582019-05-27 12:14:10 +010046 rInputDecoder[n * K + channelInput];
47 outval += weight * rInputDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000048 }
49
Francis Murtagh43aec582019-05-27 12:14:10 +010050 if (biasEnabled)
telsoa014fcda012018-03-09 14:13:49 +000051 {
Francis Murtagh43aec582019-05-27 12:14:10 +010052 rBiasDecoder[channelOutput];
53 outval += rBiasDecoder.Get();
telsoa014fcda012018-03-09 14:13:49 +000054 }
55
Francis Murtagh43aec582019-05-27 12:14:10 +010056 rOutputEncoder[n * outputSize + channelOutput];
57 rOutputEncoder.Set(outval);
telsoa014fcda012018-03-09 14:13:49 +000058 }
59 }
60}
61
62} //namespace armnn