blob: 2eb764ac9052221d229f2a89b53a83687c85e79f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Lya4d748b2023-03-28 22:06:56 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "reduction.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
// Common base node for all TOSA reduction operators (REDUCE_ALL/ANY/MAX/
// MIN/PRODUCT/SUM). Takes exactly one input and one output tensor and a
// single Axis attribute selecting the dimension to reduce.
template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, op_, id_)
{
    setRequiredOperands(1, 1);
    // Reduction is defined for input ranks 1 through 4.
    setRequiredRank(1, 4);

    // Populates the `attribute` member from attribute_; presumably makes an
    // owned copy, since the destructor deletes it -- confirm against the
    // INIT_ATTRIBUTE macro definition.
    INIT_ATTRIBUTE(Axis);
}
32
Tai Lya4d748b2023-03-28 22:06:56 +000033template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070034ReduceNode<Rank, Dtype>::~ReduceNode()
35{
36 if (attribute)
37 delete attribute;
38}
39
Tai Lya4d748b2023-03-28 22:06:56 +000040template <int Rank, TOSA_REF_TYPE Dtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070041int ReduceNode<Rank, Dtype>::checkTensorAttributes()
42{
43 if (validateRequiredOperands())
44 return 1;
45
46 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
47 {
48 return 1;
49 }
50
51 if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
52 {
Kevin Chengec5586c2021-10-06 14:37:37 -070053 printNodeValidationError("ReduceOp: axis must between [0, input_rank - 1]");
Eric Kunzee5e26762020-10-13 16:11:07 -070054 return 1;
55 }
56
Kevin Chengec5586c2021-10-06 14:37:37 -070057 if (inputs[0]->matchRankType(*outputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
Kevin Chengec5586c2021-10-06 14:37:37 -070059 printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
60 return 1;
61 }
62
63 if (outputs[0]->getShape()[attribute->axis()] != 1)
64 {
65 printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
Eric Kunzee5e26762020-10-13 16:11:07 -070066 return 1;
67 }
68
69 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
70 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
71
Kevin Chengec5586c2021-10-06 14:37:37 -070072 if ((!in) || (!out))
73 {
74 printNodeValidationError("ReduceOp: Input or output fail to cast to Eigen tensor since rank/type not expected");
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
78 dims[0] = this->attribute->axis();
79
80 return 0;
81}
82
// Hand-written reducers used instead of Eigen's built-in .all()/.any(),
// which regressed between Eigen 3.3.7 and 3.4.0: they now trip an assert
// in TensorMorphing.h:150, apparently because incorrect data is passed
// internally as m_impl.
struct AllReducer
{
    static const bool PacketAccess = false;

    // Fold one element into the accumulator: the result remains true
    // only while every value seen so far is true.
    void reduce(const bool val, bool* accum)
    {
        if (!val)
        {
            *accum = false;
        }
    }
    // Identity element for logical AND.
    bool initialize() const
    {
        return true;
    }
    bool finalize(const bool accum) const
    {
        return accum;
    }
};
// Logical-OR counterpart of AllReducer (same Eigen 3.4.0 workaround).
struct AnyReducer
{
    static const bool PacketAccess = false;

    // Fold one element into the accumulator: the result becomes true as
    // soon as any value seen is true.
    void reduce(const bool val, bool* accum)
    {
        if (val)
        {
            *accum = true;
        }
    }
    // Identity element for logical OR.
    bool initialize() const
    {
        return false;
    }
    bool finalize(const bool accum) const
    {
        return accum;
    }
};
118
// REDUCE_ALL: logical-AND along the configured axis. Uses the custom
// AllReducer (see comment above it) rather than Eigen's .all() to avoid
// the Eigen 3.4.0 assert; the reshape restores the rank-preserving output
// shape where the reduced dimension has size 1.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
    this->out->getTensor() =
        this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}
127
// REDUCE_ANY: logical-OR along the configured axis, via the custom
// AnyReducer (Eigen 3.4.0 .any() workaround); reshape keeps the
// rank-preserving output shape (reduced dimension == 1).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
    this->out->getTensor() =
        this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}
136
// REDUCE_MAX: element-wise maximum along the configured axis using
// Eigen's built-in .maximum() reduction.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMax<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}
144
// REDUCE_MIN: element-wise minimum along the configured axis using
// Eigen's built-in .minimum() reduction.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMin<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}
152
// REDUCE_PRODUCT for floating-point dtypes. FP16/BF16 inputs are
// accumulated in fp32 (the tensor's storage type here is float) and each
// output element is truncated back to the narrow format via fpTrunc.
// FP64 is handled separately by OpReduceProductDouble.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
            this->out->getTensor() = this->in->getTensor()
                                         .prod(this->dims)
                                         .reshape(this->out->getTensor().dimensions())
                                         .unaryExpr([](float f) { return fpTrunc<Dtype>(f); });
            break;
        case TOSA_REF_TYPE_FP32:
            this->out->getTensor() =
                this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
            break;
        default:
            // Unreachable for the instantiated dtypes; guards future additions.
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return GraphNode::eval();
}
175
// Multiplicative reducer for the FP64 reference path
// (used by OpReduceProductDouble).
struct ProductDoubleReducer
{
    static const bool PacketAccess = false;

    // Fold one element into the running product.
    void reduce(const double val, double* accum)
    {
        *accum = *accum * val;
    }
    // Identity element for multiplication.
    double initialize() const
    {
        return 1.0;
    }
    double finalize(const double accum) const
    {
        return accum;
    }
};
192
// REDUCE_PRODUCT for the FP64 reference type, using the custom
// ProductDoubleReducer. Only TOSA_REF_TYPE_FP64 is instantiated; any
// other dtype reaching here is an internal error.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceProductDouble<Rank, Dtype>::eval()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP64:
            this->out->getTensor() = this->in->getTensor()
                                         .reduce(this->dims, ProductDoubleReducer())
                                         .reshape(this->out->getTensor().dimensions());
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return GraphNode::eval();
}
209
// REDUCE_SUM for FP16/BF16/FP32/INT32. FP16/BF16 accumulate in fp32 and
// truncate each output element back via fpTrunc. Note: the INT32 case
// here does not range-check intermediate sums; the checked variant is
// OpReduceSumInt below (which one a graph uses is decided elsewhere --
// see the instantiations at the bottom of this file).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
            this->out->getTensor() = this->in->getTensor()
                                         .sum(this->dims)
                                         .reshape(this->out->getTensor().dimensions())
                                         .unaryExpr([](float f) { return fpTrunc<Dtype>(f); });
            break;
        case TOSA_REF_TYPE_FP32:
        case TOSA_REF_TYPE_INT32:
            this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
            break;
        default:
            // Unreachable for the instantiated dtypes; guards future additions.
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return GraphNode::eval();
}
232
// Summation reducer for INT32 REDUCE_SUM that enforces the requirement
// that every intermediate sum stays representable in int32. Accumulates
// each step in 64 bits and REQUIREs the result back into i32 range
// before narrowing. NOTE(review): parent_sgt is stored but not visibly
// read here -- presumably the REQUIRE macro reports failures through it;
// confirm against the macro definition.
struct SumRequiresReducer
{
    static const bool PacketAccess = false;
    SumRequiresReducer(SubgraphTraverser* parent_sgt)
        : parent_sgt(parent_sgt)
    {}
    // Fold one element into the accumulator with overflow checking.
    void reduce(const int32_t val, int32_t* accum)
    {
        int64_t res_in_64 = static_cast<int64_t>(*accum) + val;
        int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
        int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
        REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range");
        *accum = static_cast<int32_t>(res_in_64);
    }
    // Identity element for addition.
    int32_t initialize() const
    {
        return 0;
    }
    int32_t finalize(const int32_t accum) const
    {
        return accum;
    }

private:
    SubgraphTraverser* parent_sgt;
};
259
// Range-checked INT32 REDUCE_SUM: every intermediate accumulation is
// validated against i32 range by SumRequiresReducer (see above).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor()
                                 .reduce(this->dims, SumRequiresReducer(this->parent_sgt))
                                 .reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}
269
// Additive reducer for the FP64 reference path (used by OpReduceSumDouble).
struct SumDoubleReducer
{
    static const bool PacketAccess = false;

    // Fold one element into the running sum.
    void reduce(const double val, double* accum)
    {
        *accum = *accum + val;
    }
    // Identity element for addition.
    double initialize() const
    {
        return 0.0;
    }
    double finalize(const double accum) const
    {
        return accum;
    }
};
286
// REDUCE_SUM for the FP64 reference type. When the global abs_mode flag
// is set, the absolute value of each input element is summed instead --
// presumably used to compute error bounds for compliance checking;
// confirm against g_func_config usage elsewhere.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumDouble<Rank, Dtype>::eval()
{
    typename ReduceNode<Rank, Dtype>::TIn in_val = this->in->getTensor();
    if (g_func_config.abs_mode)
    {
        // in abs_mode: take abs values of in value
        in_val = in_val.abs();
    }
    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP64:
            this->out->getTensor() =
                in_val.reduce(this->dims, SumDoubleReducer()).reshape(this->out->getTensor().dimensions());
            break;
        default:
            // Only FP64 is instantiated; anything else is an internal error.
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return GraphNode::eval();
}
308
// template explicit instantiation
// Boolean reductions.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);

// Min/max cover all integer and floating dtypes, including the FP64
// reference type.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);

// Product: narrow floats go through OpReduceProduct (fp32 accumulate +
// fpTrunc); FP64 uses the dedicated double path.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);

// Sum: floats via OpReduceSum; FP64 via the double path; INT32 via the
// i32-range-checked OpReduceSumInt.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);

// Base-node instantiations shared by all reduction ops above.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(ReduceNode, FP64);