blob: 392ef8e5baff64f8b392bf12bf98bb24ba3f4de9 [file] [log] [blame]
narpra011e4c31d2018-09-28 11:07:51 +01001//
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
narpra011e4c31d2018-09-28 11:07:51 +01003// SPDX-License-Identifier: MIT
4//
5
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00006#include "Reduce.hpp"
narpra011e4c31d2018-09-28 11:07:51 +01007
Matthew Sloyan171214c2020-09-09 09:07:37 +01008#include <armnn/utility/NumericCast.hpp>
9
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000010#include <backendsCommon/WorkloadData.hpp>
11
narpra011e4c31d2018-09-28 11:07:51 +010012#include <cmath>
13#include <cstddef>
14#include <functional>
15#include <limits>
16
17namespace armnn
18{
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000019
narpra011e4c31d2018-09-28 11:07:51 +010020bool NextIndex(const unsigned int numDims, const armnn::TensorShape& dims, std::vector<unsigned int>& current)
21{
22 unsigned int carry = 1;
23
24 for (unsigned int idx = numDims; idx-- > 0; )
25 {
26 unsigned int current_val = current[idx] + carry;
27 if (dims[idx] == current_val)
28 {
29 current[idx] = 0;
30 }
31 else
32 {
33 current[idx] = current_val;
34 carry = 0;
35 break;
36 }
37 }
38 return (carry == 0);
39}
40
James Conroy4d1ff582019-06-10 17:06:39 +010041unsigned int ReducedOutputOffset(const unsigned int numDims,
42 const armnn::TensorShape& dims,
43 std::vector<unsigned int>& index,
44 const unsigned int numAxis,
45 const std::vector<unsigned int>& axis)
46{
47 unsigned int offset = 0;
narpra011e4c31d2018-09-28 11:07:51 +010048 for (unsigned int idx = 0; idx < numDims; ++idx)
49 {
50 bool isAxis = false;
51 if (!axis.empty())
52 {
53 for (unsigned int axisIdx = 0; axisIdx < numAxis; ++axisIdx)
54 {
55 if (idx == axis[axisIdx])
56 {
57 isAxis = true;
58 break;
59 }
60 }
61 }
62 if (!isAxis)
63 {
James Conroy4d1ff582019-06-10 17:06:39 +010064 offset = offset * dims[idx] + index[idx];
narpra011e4c31d2018-09-28 11:07:51 +010065 }
66 }
67 return offset;
68}
narpra011e4c31d2018-09-28 11:07:51 +010069
narpra011e4c31d2018-09-28 11:07:51 +010070
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +000071void Reduce(const TensorInfo& inputInfo,
72 const TensorInfo& outputInfo,
73 Decoder<float>& input,
74 Encoder<float>& output,
75 const std::vector<uint32_t> axis,
76 const ReduceOperation reduceOperation)
77{
narpra011e4c31d2018-09-28 11:07:51 +010078 armnn::TensorShape inputDims = inputInfo.GetShape();
Sadik Armagana2747482021-02-09 10:28:54 +000079 unsigned int inputNumDims = inputInfo.GetNumDimensions();
80 unsigned int numOutputs = outputInfo.GetNumElements();
narpra011e4c31d2018-09-28 11:07:51 +010081
Sadik Armagana2747482021-02-09 10:28:54 +000082 // Initialise temp output
83 std::vector<float> tempOut(numOutputs);
Teresa Charlin2226ca92021-02-11 23:05:40 +000084 switch(reduceOperation)
narpra011e4c31d2018-09-28 11:07:51 +010085 {
Teresa Charlin2226ca92021-02-11 23:05:40 +000086 case ReduceOperation::Mean:
87 case ReduceOperation::Sum:
88 std::fill(tempOut.begin(), tempOut.end(), 0.0);
89 break;
90 case ReduceOperation::Max:
91 std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
92 break;
93 case ReduceOperation::Min:
94 std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
95 break;
96 default:
97 throw armnn::InvalidArgumentException("Unknown reduce method: " +
98 std::to_string(static_cast<int>(reduceOperation)));
narpra011e4c31d2018-09-28 11:07:51 +010099 }
100
Sadik Armagana2747482021-02-09 10:28:54 +0000101 // Initialise temp index
102 std::vector<unsigned int> tempIndex(inputNumDims, 0);
narpra011e4c31d2018-09-28 11:07:51 +0100103
104 std::vector<unsigned int> resolvedAxis = axis;
105 if (resolvedAxis.empty())
106 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000107 for (unsigned int idx = 0; idx < inputNumDims; ++idx)
108 {
109 resolvedAxis.push_back(idx);
110 }
narpra011e4c31d2018-09-28 11:07:51 +0100111 }
Matthew Sloyan171214c2020-09-09 09:07:37 +0100112 auto numResolvedAxis = armnn::numeric_cast<unsigned int>(resolvedAxis.size());
narpra011e4c31d2018-09-28 11:07:51 +0100113
Sadik Armagana2747482021-02-09 10:28:54 +0000114 // Iterates through input_data and operates over the reduced axis
narpra011e4c31d2018-09-28 11:07:51 +0100115 for (bool hasNext = true; hasNext; hasNext = NextIndex(inputNumDims, inputDims, tempIndex))
116 {
James Conroy4d1ff582019-06-10 17:06:39 +0100117 unsigned int inputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex, 0, {});
118 unsigned int outputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex,
119 numResolvedAxis, resolvedAxis);
120 input[inputOffset];
Sadik Armagana2747482021-02-09 10:28:54 +0000121 auto inputValue = input.Get();
122 if (reduceOperation == ReduceOperation::Max)
123 {
124 if (inputValue > tempOut[outputOffset])
125 {
126 tempOut[outputOffset] = inputValue;
127 }
128 }
129 else if (reduceOperation == ReduceOperation::Min)
130 {
131 if (inputValue < tempOut[outputOffset])
132 {
133 tempOut[outputOffset] = inputValue;
134 }
135 }
136 else
137 {
138 tempOut[outputOffset] += inputValue;
139 }
narpra011e4c31d2018-09-28 11:07:51 +0100140 }
141
Sadik Armagana2747482021-02-09 10:28:54 +0000142 // Takes average by num of elements added to get MEAN
narpra011e4c31d2018-09-28 11:07:51 +0100143 size_t numElementsInAxis = 1;
144 for (unsigned int idx = 0; idx < numResolvedAxis; ++idx)
145 {
James Conroy4d1ff582019-06-10 17:06:39 +0100146 unsigned int current = inputDims[resolvedAxis[idx]];
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100147 ARMNN_ASSERT(armnn::numeric_cast<float>(current) <
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000148 (std::numeric_limits<float>::max() / armnn::numeric_cast<float>(numElementsInAxis)));
narpra011e4c31d2018-09-28 11:07:51 +0100149 numElementsInAxis *= current;
150 }
Sadik Armagana2747482021-02-09 10:28:54 +0000151
152 for (unsigned int idx = 0; idx < numOutputs; ++idx)
153 {
154 output[idx];
155 if (reduceOperation == ReduceOperation::Mean)
narpra011e4c31d2018-09-28 11:07:51 +0100156 {
Sadik Armagana2747482021-02-09 10:28:54 +0000157 if (numElementsInAxis > 0)
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000158 {
Sadik Armagana2747482021-02-09 10:28:54 +0000159 output.Set(tempOut[idx] / armnn::numeric_cast<float>(numElementsInAxis));
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000160 }
Sadik Armagana2747482021-02-09 10:28:54 +0000161 }
162 else
163 {
164 output.Set(tempOut[idx]);
narpra011e4c31d2018-09-28 11:07:51 +0100165 }
166 }
167}
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000168
169} //namespace armnn