blob: ec4fdb8839fec8803b2aa93cb7a9fd0d477caa3f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Softmax.hpp"

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
10
11namespace armnn
12{
13
Francis Murtagh07f21212019-07-23 09:50:50 +010014unsigned int GetNumElementsBetween(const TensorShape& shape,
15 unsigned int firstAxisInclusive,
16 unsigned int lastAxisExclusive)
telsoa014fcda012018-03-09 14:13:49 +000017{
Francis Murtagh07f21212019-07-23 09:50:50 +010018 BOOST_ASSERT(0 <= firstAxisInclusive);
19 BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
20 BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
21 unsigned int count = 1;
22 for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
telsoa014fcda012018-03-09 14:13:49 +000023 {
Francis Murtagh07f21212019-07-23 09:50:50 +010024 count *= shape[i];
25 }
26 return count;
27}
28
29/// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
30void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
31{
32 BOOST_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()),
33 "Required axis index greater than number of dimensions.");
34 BOOST_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()),
35 "Required axis index lower than negative of the number of dimensions");
36
37 unsigned int uAxis = axis < 0 ?
38 inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis))
39 : static_cast<unsigned int>(axis);
40
41 const TensorShape& inputShape = inputTensorInfo.GetShape();
42 const unsigned int outerSize = GetNumElementsBetween(inputShape, 0, uAxis);
43 const unsigned int axisSize = inputShape[uAxis];
44 const unsigned int innerSize = GetNumElementsBetween(inputShape, uAxis + 1, inputShape.GetNumDimensions());
45
46 for (unsigned int outer = 0; outer < outerSize; ++outer)
47 {
48 unsigned int inputBeginIdx = outer * axisSize * innerSize;
49 unsigned int inputEndIdx = inputBeginIdx + axisSize * innerSize;
50 unsigned int outputBeginIdx = outer * axisSize * innerSize;
51
52 for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx)
telsoa014fcda012018-03-09 14:13:49 +000053 {
Francis Murtagh07f21212019-07-23 09:50:50 +010054 // Find max
55 float maxValue = std::numeric_limits<float>::lowest();
56 for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
telsoa014fcda012018-03-09 14:13:49 +000057 {
Francis Murtagh07f21212019-07-23 09:50:50 +010058 in[iter];
59 maxValue = std::max(maxValue, in.Get());
telsoa014fcda012018-03-09 14:13:49 +000060 }
telsoa014fcda012018-03-09 14:13:49 +000061
Francis Murtagh07f21212019-07-23 09:50:50 +010062 // Compute sum
63 float sum = 0.0f;
64 for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
65 {
66 in[iter];
67 sum += std::exp((in.Get() - maxValue) * beta);
68 }
telsoa014fcda012018-03-09 14:13:49 +000069
Francis Murtagh07f21212019-07-23 09:50:50 +010070 // Compute result
71 unsigned int outputIter = outputBeginIdx;
72 out[outputIter];
73 for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize)
74 {
75 out[outputIter];
76 in[iter];
77 out.Set(std::exp((in.Get() - maxValue) * beta) / sum);
78 }
telsoa014fcda012018-03-09 14:13:49 +000079 }
80 }
81}
82
83} //namespace armnn