blob: 103d62a8df16674fe4099f38808c5d63caab5239 [file] [log] [blame]
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "LogSoftmax.hpp"

#include <armnnUtils/TensorUtils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <algorithm>
#include <cmath>
#include <cstdlib>

#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
15
16namespace
17{
18
19inline bool ValidateAxis(int axis, unsigned int numDimensions)
20{
21 const int sNumDimensions = boost::numeric_cast<int>(numDimensions);
22 return axis < sNumDimensions && axis >= -sNumDimensions;
23}
24
25} // anonymous namespace
26
27namespace armnn
28{
29
30void LogSoftmax(Decoder<float>& input,
31 Encoder<float>& output,
32 const TensorInfo& inputInfo,
33 const LogSoftmaxDescriptor& descriptor)
34{
35 const unsigned int numDimensions = inputInfo.GetNumDimensions();
36
37 bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
38 BOOST_ASSERT_MSG(axisIsValid,
39 "Axis index is not in range [-numDimensions, numDimensions).");
Jan Eilers8eb25602020-03-09 12:13:48 +000040 IgnoreUnused(axisIsValid);
Aron Virginas-Tare662a942019-10-14 15:12:00 +010041
42 unsigned int uAxis = descriptor.m_Axis < 0 ?
43 numDimensions - boost::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
44 boost::numeric_cast<unsigned int>(descriptor.m_Axis);
45
46 const TensorShape& inputShape = inputInfo.GetShape();
47 const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
48 const unsigned int axisSize = inputShape[uAxis];
49 const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
50 uAxis + 1,
51 inputShape.GetNumDimensions());
52
53 for (unsigned int outer = 0; outer < outerSize; ++outer)
54 {
55 for (unsigned int inner = 0; inner < innerSize; ++inner)
56 {
57 // Find max
58 input[outer * axisSize * innerSize + inner];
59 float maxValue = input.Get();
60 for (unsigned int i = 1u; i < axisSize; ++i)
61 {
62 input[(outer * axisSize + i) * innerSize + inner];
63 maxValue = std::max(maxValue, input.Get());
64 }
65
66 // Compute sum
67 float sum = 0.0f;
68 for (unsigned int i = 0u; i < axisSize; ++i)
69 {
70 input[(outer * axisSize + i) * innerSize + inner];
71 sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
72 }
73
74 // Compute log sum
75 const float logSum = std::log(sum);
76
77 // Compute result
78 for (unsigned int i = 0u; i < axisSize; ++i)
79 {
80 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
81
82 input [index];
83 output[index];
84
85 output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
86 }
87 }
88 }
89}
90
91} // namespace armnn