blob: 2b6384913ef877d524a1e4abb9c59bb9fda735b3 [file] [log] [blame]
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "LogSoftmax.hpp"
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/TensorUtils.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009#include <armnn/utility/Assert.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Aron Virginas-Tare662a942019-10-14 15:12:00 +010012
13#include <cmath>
14
Aron Virginas-Tare662a942019-10-14 15:12:00 +010015namespace
16{
17
18inline bool ValidateAxis(int axis, unsigned int numDimensions)
19{
Matthew Sloyan171214c2020-09-09 09:07:37 +010020 const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
Aron Virginas-Tare662a942019-10-14 15:12:00 +010021 return axis < sNumDimensions && axis >= -sNumDimensions;
22}
23
24} // anonymous namespace
25
26namespace armnn
27{
28
29void LogSoftmax(Decoder<float>& input,
30 Encoder<float>& output,
31 const TensorInfo& inputInfo,
32 const LogSoftmaxDescriptor& descriptor)
33{
34 const unsigned int numDimensions = inputInfo.GetNumDimensions();
35
36 bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010037 ARMNN_ASSERT_MSG(axisIsValid,
Aron Virginas-Tare662a942019-10-14 15:12:00 +010038 "Axis index is not in range [-numDimensions, numDimensions).");
Jan Eilers8eb25602020-03-09 12:13:48 +000039 IgnoreUnused(axisIsValid);
Aron Virginas-Tare662a942019-10-14 15:12:00 +010040
41 unsigned int uAxis = descriptor.m_Axis < 0 ?
Matthew Sloyan171214c2020-09-09 09:07:37 +010042 numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
43 armnn::numeric_cast<unsigned int>(descriptor.m_Axis);
Aron Virginas-Tare662a942019-10-14 15:12:00 +010044
45 const TensorShape& inputShape = inputInfo.GetShape();
46 const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
47 const unsigned int axisSize = inputShape[uAxis];
48 const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
49 uAxis + 1,
50 inputShape.GetNumDimensions());
51
52 for (unsigned int outer = 0; outer < outerSize; ++outer)
53 {
54 for (unsigned int inner = 0; inner < innerSize; ++inner)
55 {
56 // Find max
57 input[outer * axisSize * innerSize + inner];
58 float maxValue = input.Get();
59 for (unsigned int i = 1u; i < axisSize; ++i)
60 {
61 input[(outer * axisSize + i) * innerSize + inner];
62 maxValue = std::max(maxValue, input.Get());
63 }
64
65 // Compute sum
66 float sum = 0.0f;
67 for (unsigned int i = 0u; i < axisSize; ++i)
68 {
69 input[(outer * axisSize + i) * innerSize + inner];
70 sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
71 }
72
73 // Compute log sum
74 const float logSum = std::log(sum);
75
76 // Compute result
77 for (unsigned int i = 0u; i < axisSize; ++i)
78 {
79 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
80
81 input [index];
82 output[index];
83
84 output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
85 }
86 }
87 }
88}
89
90} // namespace armnn