// blob: cd9d55f293d8f6c196bb6c1daa9b9134c7bf5aae [file] [log] [blame]
// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "reduction.h"
#include "quant_util.h"
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
// Constructs the common base node shared by all reduce operators.
//
// sgt_, op_ and id_ are forwarded unchanged to GraphNode. attribute_ is not
// referenced by name in this body.
// NOTE(review): INIT_ATTRIBUTE is a project macro defined elsewhere; presumably
// it initialises the `attribute` member from attribute_ — confirm.
template <int Rank, DType Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_,
                                    const Op& op_,
                                    TosaAttributeBase* attribute_,
                                    uint64_t id_)
    : GraphNode(sgt_, op_, id_)
{
    // NOTE(review): presumably "one input, one output" and "rank between 0 and 4" —
    // confirm against the GraphNode declarations.
    this->setRequiredOperands(1, 1);
    this->setRequiredRank(0, 4);

    INIT_ATTRIBUTE(Axis);
}
// Releases the attribute object held by this node.
template <int Rank, DType Dtype>
ReduceNode<Rank, Dtype>::~ReduceNode()
{
    // Deleting a null pointer is a no-op, so no explicit null check is required.
    delete attribute;
}
// Validates the operands of a reduce node and caches the typed tensor pointers.
//
// Checks, in order: operand validation, rank validation of the input and the
// output, that the axis lies in [0, input rank), that matchRankType() reports no
// mismatch between input and output, and that the output extent along the axis
// is 1. On success the `in`/`out` members point at the typed tensors and
// dims[0] holds the reduction axis.
//
// Returns 0 when every check passes, 1 as soon as any check fails.
template <int Rank, DType Dtype>
int ReduceNode<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
    {
        return 1;
    }

    auto* const srcTensor = inputs[0];
    auto* const dstTensor = outputs[0];

    if (validateRequiredRank(srcTensor))
    {
        return 1;
    }
    if (validateRequiredRank(dstTensor))
    {
        return 1;
    }

    const auto axis = attribute->axis();
    if (!(axis >= 0 && axis < srcTensor->getRank()))
    {
        printNodeValidationError("ReduceOp: axis must between [0, input_rank - 1]");
        return 1;
    }

    // A non-zero result from matchRankType() is treated as a mismatch.
    if (srcTensor->matchRankType(*dstTensor))
    {
        printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
        return 1;
    }

    // The reduced axis is kept in the output, but with extent 1.
    if (dstTensor->getShape()[axis] != 1)
    {
        printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(srcTensor);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(dstTensor);
    if (in == nullptr || out == nullptr)
    {
        printNodeValidationError("ReduceOp: Input or output fail to cast to Eigen tensor since rank/type not expected");
        return 1;
    }

    dims[0] = axis;

    return 0;
}
// These two reducers work around a bug introduced in Eigen between 3.3.7 and 3.4.0.
// The built-in .any() and .all() operations now fail on an assert in TensorMorphing.h:150,
// which seems to be caused by incorrect data being passed internally as m_impl.
// Custom reducer computing the logical AND of a sequence of bools.
struct AllReducer {
    static const bool PacketAccess = false;

    // An empty conjunction is true.
    bool initialize() const { return true; }

    // Folds `val` into `*accum`: a single false value latches the accumulator to false.
    void reduce(const bool val, bool* accum) {
        if (!val) {
            *accum = false;
        }
    }

    // The accumulator already is the result.
    bool finalize(const bool accum) const { return accum; }
};
// Custom reducer computing the logical OR of a sequence of bools.
struct AnyReducer {
    static const bool PacketAccess = false;

    // An empty disjunction is false.
    bool initialize() const { return false; }

    // Folds `val` into `*accum`: a single true value latches the accumulator to true.
    void reduce(const bool val, bool* accum) {
        if (val) {
            *accum = true;
        }
    }

    // The accumulator already is the result.
    bool finalize(const bool accum) const { return accum; }
};
// Reduces the input along this->dims with AllReducer (logical AND), reshapes the
// result to the output tensor's dimensions, stores it in the output tensor and
// returns the result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    dst = src.reduce(this->dims, AllReducer()).reshape(dst.dimensions());

    return GraphNode::eval();
}
// Reduces the input along this->dims with AnyReducer (logical OR), reshapes the
// result to the output tensor's dimensions, stores it in the output tensor and
// returns the result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    dst = src.reduce(this->dims, AnyReducer()).reshape(dst.dimensions());

    return GraphNode::eval();
}
// Takes the maximum of the input along this->dims, reshapes the result to the
// output tensor's dimensions, stores it in the output tensor and returns the
// result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceMax<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    dst = src.maximum(this->dims).reshape(dst.dimensions());

    return GraphNode::eval();
}
// Takes the minimum of the input along this->dims, reshapes the result to the
// output tensor's dimensions, stores it in the output tensor and returns the
// result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceMin<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    dst = src.minimum(this->dims).reshape(dst.dimensions());

    return GraphNode::eval();
}
// Multiplies the input elements together along this->dims, reshapes the result
// to the output tensor's dimensions and stores it in the output tensor.
//
// For DType_FP16 and DType_BF16 every resulting element is additionally passed
// through fpTrunc<Dtype>; all other data types store the product unchanged.
// Returns the result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    const bool truncateResult = (Dtype == DType_FP16) || (Dtype == DType_BF16);
    if (truncateResult)
    {
        dst = src.prod(this->dims)
                  .reshape(dst.dimensions())
                  .unaryExpr([](float value) { return fpTrunc<Dtype>(value); });
    }
    else
    {
        dst = src.prod(this->dims).reshape(dst.dimensions());
    }

    return GraphNode::eval();
}
// Sums the input elements along this->dims, reshapes the result to the output
// tensor's dimensions and stores it in the output tensor.
//
// For DType_FP16 and DType_BF16 every resulting element is additionally passed
// through fpTrunc<Dtype>; all other data types store the sum unchanged.
// Returns the result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    const bool truncateResult = (Dtype == DType_FP16) || (Dtype == DType_BF16);
    if (truncateResult)
    {
        dst = src.sum(this->dims)
                  .reshape(dst.dimensions())
                  .unaryExpr([](float value) { return fpTrunc<Dtype>(value); });
    }
    else
    {
        dst = src.sum(this->dims).reshape(dst.dimensions());
    }

    return GraphNode::eval();
}
// Custom reducer that sums int32 values and runs a REQUIRE check that every
// intermediate sum is still representable as an int32.
struct SumRequiresReducer {
    static const bool PacketAccess = false;

    SumRequiresReducer(SubgraphTraverser* parent_sgt) : parent_sgt(parent_sgt) {}

    // The running sum starts at zero.
    int32_t initialize() const { return 0; }

    // Adds `val` to `*accum`. The addition is carried out in 64 bits so that the
    // range check below sees the exact mathematical sum.
    void reduce(const int32_t val, int32_t* accum) {
        const int64_t i32_min_in_64 = std::numeric_limits<int32_t>::min();
        const int64_t i32_max_in_64 = std::numeric_limits<int32_t>::max();
        const int64_t res_in_64     = int64_t{ *accum } + int64_t{ val };

        // NOTE(review): REQUIRE is a project macro defined elsewhere; its condition is
        // kept verbatim in case the macro reports the expression text — confirm.
        REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpReduceSum: result not in i32 range");

        *accum = static_cast<int32_t>(res_in_64);
    }

    // The accumulator already is the result.
    int32_t finalize(const int32_t accum) const { return accum; }

private:
    // Not read explicitly anywhere in this struct.
    // NOTE(review): presumably consumed by the REQUIRE macro — confirm.
    SubgraphTraverser* parent_sgt;
};
// Sums the input along this->dims using SumRequiresReducer (constructed with this
// node's parent_sgt), reshapes the result to the output tensor's dimensions,
// stores it in the output tensor and returns the result of GraphNode::eval().
template <int Rank, DType Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
    auto&& dst = this->out->getTensor();
    auto&& src = this->in->getTensor();

    dst = src.reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(dst.dimensions());

    return GraphNode::eval();
}
// Explicit template instantiations: one macro invocation per (operator, data type)
// pair provided by this translation unit.
// NOTE(review): DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE is defined elsewhere; its
// name suggests it instantiates the operator for ranks 1 through 6 — confirm.
// NOTE(review): the ReduceNode constructor calls setRequiredRank(0, 4); confirm that
// instantiating beyond that range is intentional.

// Boolean reductions.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
// Maximum: floating-point and integer types.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
// Minimum: floating-point and integer types.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
// Product: floating-point types only.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
// Sum: floating-point types use OpReduceSum.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
// Sum: INT32 uses OpReduceSumInt, whose eval() reduces with SumRequiresReducer.
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);