// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "reduction.h"
#include "quant_util.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

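// ReduceNode is the shared base class for the TOSA reduction operators
// (REDUCE_ALL, REDUCE_ANY, REDUCE_MAX, REDUCE_MIN, REDUCE_PRODUCT, REDUCE_SUM).
// It holds the axis attribute and the input/output tensor handles used by the
// per-operator eval() implementations below.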
template <int Rank, DType Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
    : GraphNode(sgt_, op_, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(0, 4);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
ReduceNode<Rank, Dtype>::~ReduceNode()
{
    if (attribute)
        delete attribute;
}

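// Validates the operator's tensors and attributes: one input and one output of
// matching rank and type, an axis within [0, input_rank - 1], and an output
// shape of 1 along the reduced axis. Also caches the Eigen tensor handles and
// the reduction dimension used by eval().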
template <int Rank, DType Dtype>
int ReduceNode<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
    {
        printNodeValidationError("ReduceOp: axis must be in the range [0, input_rank - 1]");
        return 1;
    }

    if (inputs[0]->matchRankType(*outputs[0]))
    {
        printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
        return 1;
    }

    if (outputs[0]->getShape()[attribute->axis()] != 1)
    {
        printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if ((!in) || (!out))
    {
        printNodeValidationError("ReduceOp: Input or output failed to cast to an Eigen tensor; rank/type not as expected");
        return 1;
    }

    dims[0] = this->attribute->axis();

    return 0;
}

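// Each eval() applies the corresponding Eigen reduction along the configured
// axis. The reduction drops the reduced dimension, so the result is reshaped
// back to the output tensor's dimensions (which keep that axis with size 1).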
template <int Rank, DType Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
int OpReduceMax<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
int OpReduceMin<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

template <int Rank, DType Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

struct SumRequiresReducer {
    static const bool PacketAccess = false;
    SumRequiresReducer(SubgraphTraverser* parent_sgt) : parent_sgt(parent_sgt) {}
    void reduce(const int32_t val, int32_t* accum) {
        int64_t res_in_64     = static_cast<int64_t>(*accum) + val;
        int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
        int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
        REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range");
        *accum = static_cast<int32_t>(res_in_64);
    }
    int32_t initialize() const { return 0; }
    int32_t finalize(const int32_t accum) const { return accum; }

private:
    SubgraphTraverser* parent_sgt;
};

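// Integer REDUCE_SUM uses the checked reducer above instead of Eigen's plain
// sum() so that 32-bit overflow is reported rather than silently wrapped.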
template <int Rank, DType Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
    this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());

    return GraphNode::eval();
}

// template explicit instantiation
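// DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE is defined elsewhere in the
// reference model; as its name suggests, it presumably instantiates the given
// operator template for ranks 1 through 6 with the listed data type.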
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);

DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);