blob: 6ea333b405770b6a7a21802f5b3fe364cbe99b8a [file] [log] [blame]
//
// Copyright © 2021, 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Reduce.hpp"

#include <armnn/utility/NumericCast.hpp>

#include <armnn/backends/WorkloadData.hpp>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <limits>
15
16namespace armnn
17{
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000018
narpra011e4c31d2018-09-28 11:07:51 +010019bool NextIndex(const unsigned int numDims, const armnn::TensorShape& dims, std::vector<unsigned int>& current)
20{
21 unsigned int carry = 1;
22
23 for (unsigned int idx = numDims; idx-- > 0; )
24 {
25 unsigned int current_val = current[idx] + carry;
26 if (dims[idx] == current_val)
27 {
28 current[idx] = 0;
29 }
30 else
31 {
32 current[idx] = current_val;
33 carry = 0;
34 break;
35 }
36 }
37 return (carry == 0);
38}
39
James Conroy4d1ff582019-06-10 17:06:39 +010040unsigned int ReducedOutputOffset(const unsigned int numDims,
41 const armnn::TensorShape& dims,
42 std::vector<unsigned int>& index,
43 const unsigned int numAxis,
44 const std::vector<unsigned int>& axis)
45{
46 unsigned int offset = 0;
narpra011e4c31d2018-09-28 11:07:51 +010047 for (unsigned int idx = 0; idx < numDims; ++idx)
48 {
49 bool isAxis = false;
50 if (!axis.empty())
51 {
52 for (unsigned int axisIdx = 0; axisIdx < numAxis; ++axisIdx)
53 {
54 if (idx == axis[axisIdx])
55 {
56 isAxis = true;
57 break;
58 }
59 }
60 }
61 if (!isAxis)
62 {
James Conroy4d1ff582019-06-10 17:06:39 +010063 offset = offset * dims[idx] + index[idx];
narpra011e4c31d2018-09-28 11:07:51 +010064 }
65 }
66 return offset;
67}
narpra011e4c31d2018-09-28 11:07:51 +010068
narpra011e4c31d2018-09-28 11:07:51 +010069
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000070void Reduce(const TensorInfo& inputInfo,
71 const TensorInfo& outputInfo,
72 Decoder<float>& input,
73 Encoder<float>& output,
74 const std::vector<uint32_t> axis,
75 const ReduceOperation reduceOperation)
76{
narpra011e4c31d2018-09-28 11:07:51 +010077 armnn::TensorShape inputDims = inputInfo.GetShape();
Sadik Armagana2747482021-02-09 10:28:54 +000078 unsigned int inputNumDims = inputInfo.GetNumDimensions();
79 unsigned int numOutputs = outputInfo.GetNumElements();
narpra011e4c31d2018-09-28 11:07:51 +010080
Sadik Armagana2747482021-02-09 10:28:54 +000081 // Initialise temp output
82 std::vector<float> tempOut(numOutputs);
Teresa Charlin2226ca92021-02-11 23:05:40 +000083 switch(reduceOperation)
narpra011e4c31d2018-09-28 11:07:51 +010084 {
Teresa Charlin2226ca92021-02-11 23:05:40 +000085 case ReduceOperation::Mean:
86 case ReduceOperation::Sum:
Rob Hughesc013bc82021-07-14 09:31:31 +010087 std::fill(tempOut.begin(), tempOut.end(), 0.0f);
Teresa Charlin2226ca92021-02-11 23:05:40 +000088 break;
Teresa Charlin4e3e8312021-08-05 12:34:37 +010089 case ReduceOperation::Prod:
90 std::fill(tempOut.begin(), tempOut.end(), 1.0f);
91 break;
Teresa Charlin2226ca92021-02-11 23:05:40 +000092 case ReduceOperation::Max:
93 std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
94 break;
95 case ReduceOperation::Min:
96 std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
97 break;
98 default:
99 throw armnn::InvalidArgumentException("Unknown reduce method: " +
100 std::to_string(static_cast<int>(reduceOperation)));
narpra011e4c31d2018-09-28 11:07:51 +0100101 }
102
Sadik Armagana2747482021-02-09 10:28:54 +0000103 // Initialise temp index
104 std::vector<unsigned int> tempIndex(inputNumDims, 0);
narpra011e4c31d2018-09-28 11:07:51 +0100105
106 std::vector<unsigned int> resolvedAxis = axis;
107 if (resolvedAxis.empty())
108 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000109 for (unsigned int idx = 0; idx < inputNumDims; ++idx)
110 {
111 resolvedAxis.push_back(idx);
112 }
narpra011e4c31d2018-09-28 11:07:51 +0100113 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100114 auto numResolvedAxis = armnn::numeric_cast<unsigned int>(resolvedAxis.size());
narpra011e4c31d2018-09-28 11:07:51 +0100115
Sadik Armagana2747482021-02-09 10:28:54 +0000116 // Iterates through input_data and operates over the reduced axis
narpra011e4c31d2018-09-28 11:07:51 +0100117 for (bool hasNext = true; hasNext; hasNext = NextIndex(inputNumDims, inputDims, tempIndex))
118 {
James Conroy4d1ff582019-06-10 17:06:39 +0100119 unsigned int inputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex, 0, {});
120 unsigned int outputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex,
121 numResolvedAxis, resolvedAxis);
122 input[inputOffset];
Sadik Armagana2747482021-02-09 10:28:54 +0000123 auto inputValue = input.Get();
Teresa Charlin4e3e8312021-08-05 12:34:37 +0100124 switch(reduceOperation)
Sadik Armagana2747482021-02-09 10:28:54 +0000125 {
Teresa Charlin4e3e8312021-08-05 12:34:37 +0100126 case ReduceOperation::Mean:
127 case ReduceOperation::Sum:
128 tempOut[outputOffset] += inputValue;
129 break;
130 case ReduceOperation::Prod:
131 tempOut[outputOffset] *= inputValue;
132 break;
133 case ReduceOperation::Max:
134 if (inputValue > tempOut[outputOffset])
135 {
136 tempOut[outputOffset] = inputValue;
137 }
138 break;
139 case ReduceOperation::Min:
140 if (inputValue < tempOut[outputOffset])
141 {
142 tempOut[outputOffset] = inputValue;
143 }
144 break;
145 default:
146 throw armnn::InvalidArgumentException("Unknown reduce method: " +
147 std::to_string(static_cast<int>(reduceOperation)));
Sadik Armagana2747482021-02-09 10:28:54 +0000148 }
narpra011e4c31d2018-09-28 11:07:51 +0100149 }
150
Sadik Armagana2747482021-02-09 10:28:54 +0000151 // Takes average by num of elements added to get MEAN
narpra011e4c31d2018-09-28 11:07:51 +0100152 size_t numElementsInAxis = 1;
153 for (unsigned int idx = 0; idx < numResolvedAxis; ++idx)
154 {
James Conroy4d1ff582019-06-10 17:06:39 +0100155 unsigned int current = inputDims[resolvedAxis[idx]];
narpra011e4c31d2018-09-28 11:07:51 +0100156 numElementsInAxis *= current;
157 }
Sadik Armagana2747482021-02-09 10:28:54 +0000158
159 for (unsigned int idx = 0; idx < numOutputs; ++idx)
160 {
161 output[idx];
162 if (reduceOperation == ReduceOperation::Mean)
narpra011e4c31d2018-09-28 11:07:51 +0100163 {
Sadik Armagana2747482021-02-09 10:28:54 +0000164 if (numElementsInAxis > 0)
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000165 {
Sadik Armagana2747482021-02-09 10:28:54 +0000166 output.Set(tempOut[idx] / armnn::numeric_cast<float>(numElementsInAxis));
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000167 }
Sadik Armagana2747482021-02-09 10:28:54 +0000168 }
169 else
170 {
171 output.Set(tempOut[idx]);
narpra011e4c31d2018-09-28 11:07:51 +0100172 }
173 }
174}
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000175
176} //namespace armnn