blob: 6cb219a6cc88aafdd502f863d37bdc807d2f41b7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "Softmax.hpp"
7
8#include <cmath>
9#include <vector>
10
11namespace armnn
12{
13
telsoa01c577f2c2018-08-31 09:22:23 +010014/// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
nikraj01a121de32019-05-29 10:51:05 +010015void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta)
telsoa014fcda012018-03-09 14:13:49 +000016{
nikraj01a121de32019-05-29 10:51:05 +010017 unsigned int numChannels = inputTensorInfo.GetShape()[1];
18
19 for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++)
telsoa014fcda012018-03-09 14:13:49 +000020 {
telsoa01c577f2c2018-08-31 09:22:23 +010021 // Find maximum channel.
nikraj01a121de32019-05-29 10:51:05 +010022 in[n * numChannels];
23 float max = in.Get();
telsoa014fcda012018-03-09 14:13:49 +000024 for (unsigned int c = 1; c < numChannels; c++)
25 {
nikraj01a121de32019-05-29 10:51:05 +010026 in[n * numChannels + c];
27 float val = in.Get();
telsoa014fcda012018-03-09 14:13:49 +000028 if (val > max)
29 {
30 max = val;
31 }
32 }
33
telsoa01c577f2c2018-08-31 09:22:23 +010034 // Exponentiate all values and sum.
telsoa014fcda012018-03-09 14:13:49 +000035 std::vector<float> exponentials(numChannels);
36 float sum = 0.0f;
37 for (unsigned int c = 0; c < numChannels; c++)
38 {
nikraj01a121de32019-05-29 10:51:05 +010039 in[n * numChannels + c];
40 float val = in.Get();
telsoa014fcda012018-03-09 14:13:49 +000041 exponentials[c] = expf((val - max) * beta);
42 sum += exponentials[c];
43 }
44
telsoa01c577f2c2018-08-31 09:22:23 +010045 // Divide exponentials by sum to give outputs.
telsoa014fcda012018-03-09 14:13:49 +000046 for (unsigned int c = 0; c < numChannels; c++)
47 {
nikraj01a121de32019-05-29 10:51:05 +010048 out[n * numChannels + c];
49 out.Set(exponentials[c] / sum);
telsoa014fcda012018-03-09 14:13:49 +000050 }
51 }
52}
53
54} //namespace armnn