blob: cd9d55f293d8f6c196bb6c1daa9b9134c7bf5aae [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2022, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#include "reduction.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
23template <int Rank, DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070024ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
25 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070026{
27 setRequiredOperands(1, 1);
28 setRequiredRank(0, 4);
29
30 INIT_ATTRIBUTE(Axis);
31}
32
33template <int Rank, DType Dtype>
34ReduceNode<Rank, Dtype>::~ReduceNode()
35{
36 if (attribute)
37 delete attribute;
38}
39
40template <int Rank, DType Dtype>
41int ReduceNode<Rank, Dtype>::checkTensorAttributes()
42{
43 if (validateRequiredOperands())
44 return 1;
45
46 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
47 {
48 return 1;
49 }
50
51 if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
52 {
Kevin Chengec5586c2021-10-06 14:37:37 -070053 printNodeValidationError("ReduceOp: axis must between [0, input_rank - 1]");
Eric Kunzee5e26762020-10-13 16:11:07 -070054 return 1;
55 }
56
Kevin Chengec5586c2021-10-06 14:37:37 -070057 if (inputs[0]->matchRankType(*outputs[0]))
Eric Kunzee5e26762020-10-13 16:11:07 -070058 {
Kevin Chengec5586c2021-10-06 14:37:37 -070059 printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
60 return 1;
61 }
62
63 if (outputs[0]->getShape()[attribute->axis()] != 1)
64 {
65 printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
Eric Kunzee5e26762020-10-13 16:11:07 -070066 return 1;
67 }
68
69 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
70 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
71
Kevin Chengec5586c2021-10-06 14:37:37 -070072 if ((!in) || (!out))
73 {
74 printNodeValidationError("ReduceOp: Input or output fail to cast to Eigen tensor since rank/type not expected");
75 return 1;
76 }
Eric Kunzee5e26762020-10-13 16:11:07 -070077
78 dims[0] = this->attribute->axis();
79
80 return 0;
81}
82
// These 2 reducers are to overcome a bug introduced in Eigen between 3.3.7 and 3.4.0.
// The in-built .any and .all operations now fail on an assert in TensorMorphing.h:150,
// which seems to be due to incorrect data being passed internally as m_impl.
//
// Both follow Eigen's custom-reducer interface:
//   - initialize() supplies the starting accumulator.
//   - reduce() folds one element into it.
//   - finalize() yields the result.
// reduce() is const, like Eigen's built-in reducers, because it never modifies
// the reducer itself. PacketAccess = false presumably tells Eigen that no
// vectorised (packet) path is provided.

// Logical AND: true unless some reduced element is false.
struct AllReducer {
    static constexpr bool PacketAccess = false;
    void reduce(const bool val, bool* accum) const {
        *accum = *accum && val;
    }
    bool initialize() const { return true; }
    bool finalize(const bool accum) const { return accum; }
};

// Logical OR: false unless some reduced element is true.
struct AnyReducer {
    static constexpr bool PacketAccess = false;
    void reduce(const bool val, bool* accum) const {
        *accum = *accum || val;
    }
    bool initialize() const { return false; }
    bool finalize(const bool accum) const { return accum; }
};
102
Eric Kunzee5e26762020-10-13 16:11:07 -0700103template <int Rank, DType Dtype>
104int OpReduceAll<Rank, Dtype>::eval()
105{
James Ward24dbc422022-10-19 12:20:31 +0100106 this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
Eric Kunzee5e26762020-10-13 16:11:07 -0700107
108 return GraphNode::eval();
109}
110
111template <int Rank, DType Dtype>
112int OpReduceAny<Rank, Dtype>::eval()
113{
James Ward24dbc422022-10-19 12:20:31 +0100114 this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
Eric Kunzee5e26762020-10-13 16:11:07 -0700115
116 return GraphNode::eval();
117}
118
119template <int Rank, DType Dtype>
120int OpReduceMax<Rank, Dtype>::eval()
121{
122 this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
123
124 return GraphNode::eval();
125}
126
127template <int Rank, DType Dtype>
128int OpReduceMin<Rank, Dtype>::eval()
129{
130 this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
131
132 return GraphNode::eval();
133}
134
135template <int Rank, DType Dtype>
136int OpReduceProduct<Rank, Dtype>::eval()
137{
James Ward24dbc422022-10-19 12:20:31 +0100138 switch(Dtype)
139 {
140 case DType_FP16:
141 case DType_BF16:
142 this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
143 break;
144 default:
145 this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
146 break;
147 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700148
149 return GraphNode::eval();
150}
151
152template <int Rank, DType Dtype>
153int OpReduceSum<Rank, Dtype>::eval()
154{
James Ward24dbc422022-10-19 12:20:31 +0100155 switch(Dtype)
156 {
157 case DType_FP16:
158 case DType_BF16:
159 this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
160 break;
161 default:
162 this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
163 break;
164 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700165
166 return GraphNode::eval();
167}
168
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100169struct SumRequiresReducer {
170 static const bool PacketAccess = false;
171 SumRequiresReducer(SubgraphTraverser* parent_sgt) : parent_sgt(parent_sgt) {}
172 void reduce(const int32_t val, int32_t* accum) {
173 int64_t res_in_64 = static_cast<int64_t>(*accum) + val;
174 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
175 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
176 REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range");
177 *accum = static_cast<int32_t>(res_in_64);
178 }
179 int32_t initialize() const { return 0; }
180 int32_t finalize(const int32_t accum) const { return accum; }
181
182 private:
183 SubgraphTraverser* parent_sgt;
184};
185
186template <int Rank, DType Dtype>
187int OpReduceSumInt<Rank, Dtype>::eval()
188{
189 this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());
190
191 return GraphNode::eval();
192}
193
Eric Kunzee5e26762020-10-13 16:11:07 -0700194// template explicit instantiation
195DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
196
197DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
198
James Ward8b390432022-08-12 20:48:56 +0100199DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100200DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100201DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800202DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700203DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
204DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
205
James Ward8b390432022-08-12 20:48:56 +0100206DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100207DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100208DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800209DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700210DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
211DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
212
James Ward8b390432022-08-12 20:48:56 +0100213DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100214DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100215DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700216
James Ward8b390432022-08-12 20:48:56 +0100217DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100218DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100219DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
Jeremy Johnson7de9b452022-04-05 14:31:37 +0100220DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);