blob: f07ffd7ba9f26c9b3f06b586c8c9df14c839c048 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Lya4d748b2023-03-28 22:06:56 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "reduction.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, op_, id_)
{
    // Every reduction op consumes exactly one input tensor and produces one output.
    setRequiredOperands(1, 1);
    // Reductions are validated for input ranks 1 through 4.
    setRequiredRank(1, 4);

    // Deserialize the Axis attribute (which dimension to reduce over);
    // presumably populates the `attribute` member from attribute_ — see INIT_ATTRIBUTE.
    INIT_ATTRIBUTE(Axis);
}
32
Tai Lya4d748b2023-03-28 22:06:56 +000033template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070034ReduceNode<Rank, Dtype>::~ReduceNode()
35{
36 if (attribute)
37 delete attribute;
38}
39
Tai Lya4d748b2023-03-28 22:06:56 +000040template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070041int ReduceNode<Rank, Dtype>::checkTensorAttributes()
42{
43 if (validateRequiredOperands())
44 return 1;
45
46 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
47 {
48 return 1;
49 }
50
51 if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
52 {
Kevin Chengec5586c2021-10-06 14:37:37 -070053 printNodeValidationError("ReduceOp: axis must between [0, input_rank - 1]");
Eric Kunzee5e26762020-10-13 16:11:07 -070054 return 1;
55 }
56
Kevin Chengec5586c2021-10-06 14:37:37 -070057 if (inputs[0]->matchRankType(*outputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
Kevin Chengec5586c2021-10-06 14:37:37 -070059 printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
60 return 1;
61 }
62
63 if (outputs[0]->getShape()[attribute->axis()] != 1)
64 {
65 printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
Eric Kunzee5e26762020-10-13 16:11:07 -070066 return 1;
67 }
68
69 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
70 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
71
Kevin Chengec5586c2021-10-06 14:37:37 -070072 if ((!in) || (!out))
73 {
74 printNodeValidationError("ReduceOp: Input or output fail to cast to Eigen tensor since rank/type not expected");
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
78 dims[0] = this->attribute->axis();
79
80 return 0;
81}
82
// These 2 reducers are to overcome a bug introduced in Eigen between 3.3.7 and 3.4.0
// The in-built .any and .all operations now fail on an assert in TensorMorphing.h:150
// which seems to be due to incorrect data being passed internally as m_impl
struct AllReducer
{
    static const bool PacketAccess = false;
    // Logical-AND accumulation: once any element is false the accumulator
    // can never become true again.
    void reduce(const bool val, bool* accum)
    {
        if (!val)
        {
            *accum = false;
        }
    }
    // Identity element for AND.
    bool initialize() const
    {
        return true;
    }
    bool finalize(const bool accum) const
    {
        return accum;
    }
};
struct AnyReducer
{
    static const bool PacketAccess = false;
    // Logical-OR accumulation: once any element is true the accumulator
    // stays true.
    void reduce(const bool val, bool* accum)
    {
        if (val)
        {
            *accum = true;
        }
    }
    // Identity element for OR.
    bool initialize() const
    {
        return false;
    }
    bool finalize(const bool accum) const
    {
        return accum;
    }
};
102
Tai Lya4d748b2023-03-28 22:06:56 +0000103template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700104int OpReduceAll<Rank, Dtype>::eval()
105{
James Ward24dbc422022-10-19 12:20:31 +0100106 this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
Eric Kunzee5e26762020-10-13 16:11:07 -0700107
108 return GraphNode::eval();
109}
110
Tai Lya4d748b2023-03-28 22:06:56 +0000111template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700112int OpReduceAny<Rank, Dtype>::eval()
113{
James Ward24dbc422022-10-19 12:20:31 +0100114 this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
Eric Kunzee5e26762020-10-13 16:11:07 -0700115
116 return GraphNode::eval();
117}
118
Tai Lya4d748b2023-03-28 22:06:56 +0000119template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700120int OpReduceMax<Rank, Dtype>::eval()
121{
122 this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
123
124 return GraphNode::eval();
125}
126
Tai Lya4d748b2023-03-28 22:06:56 +0000127template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700128int OpReduceMin<Rank, Dtype>::eval()
129{
130 this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
131
132 return GraphNode::eval();
133}
134
Tai Lya4d748b2023-03-28 22:06:56 +0000135template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700136int OpReduceProduct<Rank, Dtype>::eval()
137{
James Ward24dbc422022-10-19 12:20:31 +0100138 switch(Dtype)
139 {
Tai Lya4d748b2023-03-28 22:06:56 +0000140 case TOSA_REF_TYPE_FP16:
141 case TOSA_REF_TYPE_BF16:
James Ward24dbc422022-10-19 12:20:31 +0100142 this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
143 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000144 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100145 this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
146 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000147 default:
148 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
James Ward24dbc422022-10-19 12:20:31 +0100149 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700150
151 return GraphNode::eval();
152}
153
// Multiplicative reducer used by the FP64 reference path.
struct ProductDoubleReducer
{
    static const bool PacketAccess = false;
    // Fold one value into the running product.
    void reduce(const double val, double* accum)
    {
        *accum *= val;
    }
    // Identity element for multiplication.
    double initialize() const
    {
        return 1.0;
    }
    double finalize(const double accum) const
    {
        return accum;
    }
};
170
171template <int Rank, TOSA_REF_TYPE Dtype>
172int OpReduceProductDouble<Rank, Dtype>::eval()
173{
174 switch (Dtype)
175 {
176 case TOSA_REF_TYPE_FP64:
177 this->out->getTensor() = this->in->getTensor()
178 .reduce(this->dims, ProductDoubleReducer())
179 .reshape(this->out->getTensor().dimensions());
180 break;
181 default:
182 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
183 }
184
185 return GraphNode::eval();
186}
187
188template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700189int OpReduceSum<Rank, Dtype>::eval()
190{
James Ward24dbc422022-10-19 12:20:31 +0100191 switch(Dtype)
192 {
Tai Lya4d748b2023-03-28 22:06:56 +0000193 case TOSA_REF_TYPE_FP16:
194 case TOSA_REF_TYPE_BF16:
Tai Ly307392a2023-05-12 21:42:19 +0000195 this->out->getTensor() = this->in->getTensor()
196 .sum(this->dims)
197 .reshape(this->out->getTensor().dimensions())
198 .unaryExpr([](float f) { return fpTrunc<Dtype>(f); });
James Ward24dbc422022-10-19 12:20:31 +0100199 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000200 case TOSA_REF_TYPE_FP32:
201 case TOSA_REF_TYPE_INT32:
James Ward24dbc422022-10-19 12:20:31 +0100202 this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
203 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000204 default:
205 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
James Ward24dbc422022-10-19 12:20:31 +0100206 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
208 return GraphNode::eval();
209}
210
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100211struct SumRequiresReducer {
212 static const bool PacketAccess = false;
213 SumRequiresReducer(SubgraphTraverser* parent_sgt) : parent_sgt(parent_sgt) {}
214 void reduce(const int32_t val, int32_t* accum) {
215 int64_t res_in_64 = static_cast<int64_t>(*accum) + val;
216 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
217 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
218 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range");
219 *accum = static_cast<int32_t>(res_in_64);
220 }
221 int32_t initialize() const { return 0; }
222 int32_t finalize(const int32_t accum) const { return accum; }
223
224 private:
225 SubgraphTraverser* parent_sgt;
226};
227
Tai Lya4d748b2023-03-28 22:06:56 +0000228template <int Rank, TOSA_REF_TYPE Dtype>
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100229int OpReduceSumInt<Rank, Dtype>::eval()
230{
Tai Ly307392a2023-05-12 21:42:19 +0000231 this->out->getTensor() = this->in->getTensor()
232 .reduce(this->dims, SumRequiresReducer(this->parent_sgt))
233 .reshape(this->out->getTensor().dimensions());
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100234
235 return GraphNode::eval();
236}
237
// Additive reducer used by the FP64 reference path.
struct SumDoubleReducer
{
    static const bool PacketAccess = false;
    // Fold one value into the running sum.
    void reduce(const double val, double* accum)
    {
        *accum += val;
    }
    // Identity element for addition.
    double initialize() const
    {
        return 0.0;
    }
    double finalize(const double accum) const
    {
        return accum;
    }
};
254
255template <int Rank, TOSA_REF_TYPE Dtype>
256int OpReduceSumDouble<Rank, Dtype>::eval()
257{
Tai Ly307392a2023-05-12 21:42:19 +0000258 typename ReduceNode<Rank, Dtype>::TIn in_val = this->in->getTensor();
259 if (g_func_config.abs_mode)
260 {
261 // in abs_mode: take abs values of in value
262 in_val = in_val.abs();
263 }
Tai Lya4d748b2023-03-28 22:06:56 +0000264 switch (Dtype)
265 {
266 case TOSA_REF_TYPE_FP64:
Tai Ly307392a2023-05-12 21:42:19 +0000267 this->out->getTensor() =
268 in_val.reduce(this->dims, SumDoubleReducer()).reshape(this->out->getTensor().dimensions());
Tai Lya4d748b2023-03-28 22:06:56 +0000269 break;
270 default:
271 ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
272 }
273
274 return GraphNode::eval();
275}
276
// template explicit instantiation
// REDUCE_ALL / REDUCE_ANY are defined only for boolean tensors.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);

// REDUCE_MAX / REDUCE_MIN: all float and integer types, plus FP64 reference.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);

// REDUCE_PRODUCT: float-only; FP64 goes through the dedicated double node.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);

// REDUCE_SUM: floats via OpReduceSum; FP64 via the double node; INT32 via the
// overflow-checked integer node.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);