blob: 31c6262c9afbfb90c639f3894e533dfa7e90200a [file] [log] [blame]
narpra011e4c31d2018-09-28 11:07:51 +01001//
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
narpra011e4c31d2018-09-28 11:07:51 +01003// SPDX-License-Identifier: MIT
4//
5
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00006#include "Reduce.hpp"
narpra011e4c31d2018-09-28 11:07:51 +01007
Matthew Sloyan171214c2020-09-09 09:07:37 +01008#include <armnn/utility/NumericCast.hpp>
9
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000010#include <backendsCommon/WorkloadData.hpp>
11
narpra011e4c31d2018-09-28 11:07:51 +010012#include <cmath>
13#include <cstddef>
14#include <functional>
15#include <limits>
16
17namespace armnn
18{
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000019
narpra011e4c31d2018-09-28 11:07:51 +010020bool NextIndex(const unsigned int numDims, const armnn::TensorShape& dims, std::vector<unsigned int>& current)
21{
22 unsigned int carry = 1;
23
24 for (unsigned int idx = numDims; idx-- > 0; )
25 {
26 unsigned int current_val = current[idx] + carry;
27 if (dims[idx] == current_val)
28 {
29 current[idx] = 0;
30 }
31 else
32 {
33 current[idx] = current_val;
34 carry = 0;
35 break;
36 }
37 }
38 return (carry == 0);
39}
40
James Conroy4d1ff582019-06-10 17:06:39 +010041unsigned int ReducedOutputOffset(const unsigned int numDims,
42 const armnn::TensorShape& dims,
43 std::vector<unsigned int>& index,
44 const unsigned int numAxis,
45 const std::vector<unsigned int>& axis)
46{
47 unsigned int offset = 0;
narpra011e4c31d2018-09-28 11:07:51 +010048 for (unsigned int idx = 0; idx < numDims; ++idx)
49 {
50 bool isAxis = false;
51 if (!axis.empty())
52 {
53 for (unsigned int axisIdx = 0; axisIdx < numAxis; ++axisIdx)
54 {
55 if (idx == axis[axisIdx])
56 {
57 isAxis = true;
58 break;
59 }
60 }
61 }
62 if (!isAxis)
63 {
James Conroy4d1ff582019-06-10 17:06:39 +010064 offset = offset * dims[idx] + index[idx];
narpra011e4c31d2018-09-28 11:07:51 +010065 }
66 }
67 return offset;
68}
narpra011e4c31d2018-09-28 11:07:51 +010069
narpra011e4c31d2018-09-28 11:07:51 +010070
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000071void Reduce(const TensorInfo& inputInfo,
72 const TensorInfo& outputInfo,
73 Decoder<float>& input,
74 Encoder<float>& output,
75 const std::vector<uint32_t> axis,
76 const ReduceOperation reduceOperation)
77{
narpra011e4c31d2018-09-28 11:07:51 +010078 armnn::TensorShape inputDims = inputInfo.GetShape();
Sadik Armagana2747482021-02-09 10:28:54 +000079 unsigned int inputNumDims = inputInfo.GetNumDimensions();
80 unsigned int numOutputs = outputInfo.GetNumElements();
narpra011e4c31d2018-09-28 11:07:51 +010081
Sadik Armagana2747482021-02-09 10:28:54 +000082 // Initialise temp output
83 std::vector<float> tempOut(numOutputs);
84 if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
narpra011e4c31d2018-09-28 11:07:51 +010085 {
Sadik Armagana2747482021-02-09 10:28:54 +000086 for (unsigned int idx = 0; idx < numOutputs; ++idx)
87 {
88 input[idx];
89 tempOut[idx] = input.Get();
90 }
91 }
92 else
93 {
94 std::fill(tempOut.begin(), tempOut.end(), 0.0);
narpra011e4c31d2018-09-28 11:07:51 +010095 }
96
Sadik Armagana2747482021-02-09 10:28:54 +000097 // Initialise temp index
98 std::vector<unsigned int> tempIndex(inputNumDims, 0);
narpra011e4c31d2018-09-28 11:07:51 +010099
100 std::vector<unsigned int> resolvedAxis = axis;
101 if (resolvedAxis.empty())
102 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000103 for (unsigned int idx = 0; idx < inputNumDims; ++idx)
104 {
105 resolvedAxis.push_back(idx);
106 }
narpra011e4c31d2018-09-28 11:07:51 +0100107 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100108 auto numResolvedAxis = armnn::numeric_cast<unsigned int>(resolvedAxis.size());
narpra011e4c31d2018-09-28 11:07:51 +0100109
Sadik Armagana2747482021-02-09 10:28:54 +0000110 // Iterates through input_data and operates over the reduced axis
narpra011e4c31d2018-09-28 11:07:51 +0100111 for (bool hasNext = true; hasNext; hasNext = NextIndex(inputNumDims, inputDims, tempIndex))
112 {
James Conroy4d1ff582019-06-10 17:06:39 +0100113 unsigned int inputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex, 0, {});
114 unsigned int outputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex,
115 numResolvedAxis, resolvedAxis);
116 input[inputOffset];
Sadik Armagana2747482021-02-09 10:28:54 +0000117 auto inputValue = input.Get();
118 if (reduceOperation == ReduceOperation::Max)
119 {
120 if (inputValue > tempOut[outputOffset])
121 {
122 tempOut[outputOffset] = inputValue;
123 }
124 }
125 else if (reduceOperation == ReduceOperation::Min)
126 {
127 if (inputValue < tempOut[outputOffset])
128 {
129 tempOut[outputOffset] = inputValue;
130 }
131 }
132 else
133 {
134 tempOut[outputOffset] += inputValue;
135 }
narpra011e4c31d2018-09-28 11:07:51 +0100136 }
137
Sadik Armagana2747482021-02-09 10:28:54 +0000138 // Takes average by num of elements added to get MEAN
narpra011e4c31d2018-09-28 11:07:51 +0100139 size_t numElementsInAxis = 1;
140 for (unsigned int idx = 0; idx < numResolvedAxis; ++idx)
141 {
James Conroy4d1ff582019-06-10 17:06:39 +0100142 unsigned int current = inputDims[resolvedAxis[idx]];
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100143 ARMNN_ASSERT(armnn::numeric_cast<float>(current) <
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000144 (std::numeric_limits<float>::max() / armnn::numeric_cast<float>(numElementsInAxis)));
narpra011e4c31d2018-09-28 11:07:51 +0100145 numElementsInAxis *= current;
146 }
Sadik Armagana2747482021-02-09 10:28:54 +0000147
148 for (unsigned int idx = 0; idx < numOutputs; ++idx)
149 {
150 output[idx];
151 if (reduceOperation == ReduceOperation::Mean)
narpra011e4c31d2018-09-28 11:07:51 +0100152 {
Sadik Armagana2747482021-02-09 10:28:54 +0000153 if (numElementsInAxis > 0)
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000154 {
Sadik Armagana2747482021-02-09 10:28:54 +0000155 output.Set(tempOut[idx] / armnn::numeric_cast<float>(numElementsInAxis));
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000156 }
Sadik Armagana2747482021-02-09 10:28:54 +0000157 }
158 else
159 {
160 output.Set(tempOut[idx]);
narpra011e4c31d2018-09-28 11:07:51 +0100161 }
162 }
163}
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000164
165} //namespace armnn