blob: 092689448903540fcfe020e7ee9057ad9ae5785a [file] [log] [blame]
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2019, 2024 Arm Ltd. All rights reserved.
Aron Virginas-Tare662a942019-10-14 15:12:00 +01003// SPDX-License-Identifier: MIT
4//
5
#include "LogSoftmax.hpp"

#include <armnnUtils/TensorUtils.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <algorithm>
#include <cmath>
12
Aron Virginas-Tare662a942019-10-14 15:12:00 +010013namespace
14{
15
16inline bool ValidateAxis(int axis, unsigned int numDimensions)
17{
Matthew Sloyan171214c2020-09-09 09:07:37 +010018 const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
Aron Virginas-Tare662a942019-10-14 15:12:00 +010019 return axis < sNumDimensions && axis >= -sNumDimensions;
20}
21
22} // anonymous namespace
23
24namespace armnn
25{
26
27void LogSoftmax(Decoder<float>& input,
28 Encoder<float>& output,
29 const TensorInfo& inputInfo,
30 const LogSoftmaxDescriptor& descriptor)
31{
32 const unsigned int numDimensions = inputInfo.GetNumDimensions();
33
Colm Donelanb4ef1632024-02-01 15:00:43 +000034 ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(ValidateAxis(descriptor.m_Axis, numDimensions),
35 "Axis index is not in range [-numDimensions, numDimensions).");
Aron Virginas-Tare662a942019-10-14 15:12:00 +010036
37 unsigned int uAxis = descriptor.m_Axis < 0 ?
Matthew Sloyan171214c2020-09-09 09:07:37 +010038 numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
39 armnn::numeric_cast<unsigned int>(descriptor.m_Axis);
Aron Virginas-Tare662a942019-10-14 15:12:00 +010040
41 const TensorShape& inputShape = inputInfo.GetShape();
42 const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
43 const unsigned int axisSize = inputShape[uAxis];
44 const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
45 uAxis + 1,
46 inputShape.GetNumDimensions());
47
48 for (unsigned int outer = 0; outer < outerSize; ++outer)
49 {
50 for (unsigned int inner = 0; inner < innerSize; ++inner)
51 {
52 // Find max
53 input[outer * axisSize * innerSize + inner];
54 float maxValue = input.Get();
55 for (unsigned int i = 1u; i < axisSize; ++i)
56 {
57 input[(outer * axisSize + i) * innerSize + inner];
58 maxValue = std::max(maxValue, input.Get());
59 }
60
61 // Compute sum
62 float sum = 0.0f;
63 for (unsigned int i = 0u; i < axisSize; ++i)
64 {
65 input[(outer * axisSize + i) * innerSize + inner];
66 sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
67 }
68
69 // Compute log sum
70 const float logSum = std::log(sum);
71
72 // Compute result
73 for (unsigned int i = 0u; i < axisSize; ++i)
74 {
75 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
76
77 input [index];
78 output[index];
79
80 output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
81 }
82 }
83 }
84}
85
86} // namespace armnn