//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Softmax.hpp"

#include <armnnUtils/TensorUtils.hpp>

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

13namespace armnn
14{
15
Francis Murtagh07f21212019-07-23 09:50:50 +010016/// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
17void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
18{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010019 ARMNN_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()),
Francis Murtagh07f21212019-07-23 09:50:50 +010020 "Required axis index greater than number of dimensions.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010021 ARMNN_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()),
Francis Murtagh07f21212019-07-23 09:50:50 +010022 "Required axis index lower than negative of the number of dimensions");
23
24 unsigned int uAxis = axis < 0 ?
25 inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis))
26 : static_cast<unsigned int>(axis);
27
28 const TensorShape& inputShape = inputTensorInfo.GetShape();
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +010029 const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
Francis Murtagh07f21212019-07-23 09:50:50 +010030 const unsigned int axisSize = inputShape[uAxis];
Narumol Prangnawarat4dc64a62019-09-16 17:00:22 +010031 const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
32 uAxis + 1,
33 inputShape.GetNumDimensions());
Francis Murtagh07f21212019-07-23 09:50:50 +010034
35 for (unsigned int outer = 0; outer < outerSize; ++outer)
36 {
37 unsigned int inputBeginIdx = outer * axisSize * innerSize;
38 unsigned int inputEndIdx = inputBeginIdx + axisSize * innerSize;
39 unsigned int outputBeginIdx = outer * axisSize * innerSize;
40
41 for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx)
telsoa014fcda012018-03-09 14:13:49 +000042 {
Francis Murtagh07f21212019-07-23 09:50:50 +010043 // Find max
44 float maxValue = std::numeric_limits<float>::lowest();
45 for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
telsoa014fcda012018-03-09 14:13:49 +000046 {
Francis Murtagh07f21212019-07-23 09:50:50 +010047 in[iter];
48 maxValue = std::max(maxValue, in.Get());
telsoa014fcda012018-03-09 14:13:49 +000049 }
telsoa014fcda012018-03-09 14:13:49 +000050
Francis Murtagh07f21212019-07-23 09:50:50 +010051 // Compute sum
52 float sum = 0.0f;
53 for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
54 {
55 in[iter];
56 sum += std::exp((in.Get() - maxValue) * beta);
57 }
telsoa014fcda012018-03-09 14:13:49 +000058
Francis Murtagh07f21212019-07-23 09:50:50 +010059 // Compute result
60 unsigned int outputIter = outputBeginIdx;
61 out[outputIter];
62 for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize)
63 {
64 out[outputIter];
65 in[iter];
66 out.Set(std::exp((in.Get() - maxValue) * beta) / sum);
67 }
telsoa014fcda012018-03-09 14:13:49 +000068 }
69 }
70}
71
72} //namespace armnn