blob: 6818f8c822e6bfa02e8e1f0cfafaea7f01b269c0 [file] [log] [blame]
// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ewise_unary.h"
#include "quant_util.h"
#include "template_types.h"
#include <cmath>
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
template <int Rank, TOSA_REF_TYPE Dtype>
UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
    : GraphNode(sgt_, op_, id_)
{
    // NOTE(review): (1, 1) presumably means one input and one output operand --
    // confirm against setRequiredOperands().
    setRequiredOperands(1, 1);

    // Placeholder element function. It asserts when invoked, which flags an
    // operator whose register_fcn() never installed a real implementation.
    fcn = [](InEigenType) -> OutEigenType {
        ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
        return OutEigenType();
    };
}
// Nothing to release explicitly here; member and base clean-up is implicit.
template <int Rank, TOSA_REF_TYPE Dtype>
UnaryNode<Rank, Dtype>::~UnaryNode() = default;
// Validates the node's operands and caches typed pointers to its input and
// output tensors. Returns 0 on success and 1 when validation fails.
template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::checkTensorAttributes()
{
    // Enforce the rank limit of the configured TOSA level.
    const auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
    {
        return 1;
    }

    // A unary elementwise op requires identical rank, type and shape on both
    // sides; a non-zero result from matchRankTypeShape() is treated as a mismatch.
    if (inputs[0]->matchRankTypeShape(*outputs[0]))
    {
        printNodeValidationError("UnaryNode: input and output rank/type/shape must match");
        return 1;
    }

    // Down-cast the generic tensors to the concrete template types used by eval().
    result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);

    ASSERT_MEM(a && result);

    return 0;
}
// Evaluates the node: registers the element function, applies it to the input
// tensor and then defers to GraphNode::eval(). Returns 1 if registration fails.
template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::eval()
{
    // Registration happens at evaluation time so that inputs/outputs are
    // already connected to the node when register_fcn() runs; the Clamp
    // operator depends on this ordering.
    const int registration_status = register_fcn();
    if (registration_status != 0)
    {
        return 1;
    }

    // Map the registered function over the input tensor into the output tensor.
    this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);

    return GraphNode::eval();
}
// Installs the element function for ABS.
//
// Returns 0 after installing a function for a supported type. For any other
// type the unsupported-type error is raised through ERROR_IF.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpAbs<Rank, Dtype>::register_fcn()
{
    // Magnitude of a single element.
    //
    // Zero is handled explicitly: the plain `a > 0 ? a : -a` form negates
    // +0.0 and yields -0.0, so the absolute value of positive zero would carry
    // a set sign bit. Returning a canonical zero gives +0.0 for both signed
    // zeros. All other inputs (including NaN, for which both comparisons are
    // false) follow the original ternary unchanged.
    // NOTE(review): negating the most negative INT32 value still overflows,
    // exactly as before -- confirm whether that input must be rejected.
    auto magnitude = [](InEigenType a) -> InEigenType {
        if (a == static_cast<InEigenType>(0))
        {
            return static_cast<InEigenType>(0);
        }
        return a > static_cast<InEigenType>(0) ? a : (-a);
    };

    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP32:    // No fpTrunc for FP32 as it is a no-op
        case TOSA_REF_TYPE_FP64:
        case TOSA_REF_TYPE_INT32:
            this->fcn = [magnitude](InEigenType a) -> OutEigenType { return magnitude(a); };
            break;
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
            // Reduced-precision float types pass the result through fpTrunc<Dtype>.
            this->fcn = [magnitude](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(magnitude(a)); };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for BITWISE_NOT (integer types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseNot<Rank, Dtype>::register_fcn()
{
    const bool is_supported_int =
        (Dtype == TOSA_REF_TYPE_INT8) || (Dtype == TOSA_REF_TYPE_INT16) || (Dtype == TOSA_REF_TYPE_INT32);

    if (!is_supported_int)
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
        return 0;
    }

    // Flip every bit of the element.
    this->fcn = [](InEigenType value) -> OutEigenType { return ~value; };
    return 0;
}
// Installs the element function for CEIL (floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpCeil<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        // Single-precision ceiling, then passed through fpTrunc<Dtype>.
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(ceilf(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return ceil(x); };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for CLZ (count leading zeros), INT32 only.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpClz<Rank, Dtype>::register_fcn()
{
    int32_t num_bits;
    if (Dtype == TOSA_REF_TYPE_INT32)
    {
        num_bits = 32;
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    // Scan from the most significant bit downwards and count zero bits until
    // the first set bit (or until every bit has been examined).
    this->fcn = [num_bits](int32_t value) -> int32_t {
        int32_t zero_count = 0;
        int bit_index      = num_bits - 1;
        while (bit_index >= 0 && ((value >> bit_index) & 0x1) == 0)
        {
            zero_count++;
            bit_index--;
        }
        return zero_count;
    };

    return 0;
}
// Installs the element function for COS (floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpCos<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(cos(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        if (g_func_config.abs_mode)
        {
            // ABS_ERROR bounds return: the input is passed through unchanged.
            this->fcn = [](InEigenType x) -> OutEigenType { return x; };
        }
        else
        {
            this->fcn = [](InEigenType x) -> OutEigenType { return cos(x); };
        }
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for EXP (floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpExp<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(expf(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        if (g_func_config.abs_mode)
        {
            // ABS_ERROR bounds return (1+abs(a))
            this->fcn = [](InEigenType x) -> OutEigenType { return 1.0 + (x > (InEigenType)0 ? x : (-x)); };
        }
        else
        {
            this->fcn = [](InEigenType x) -> OutEigenType { return exp(x); };
        }
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for FLOOR (floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpFloor<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        // Single-precision floor, then passed through fpTrunc<Dtype>.
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(floorf(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return floor(x); };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for LOG (natural logarithm, floating-point only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpLog<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        // Single-precision logarithm, then passed through fpTrunc<Dtype>.
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(logf(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return log(x); };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for LOGICAL_NOT (BOOL only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalNot<Rank, Dtype>::register_fcn()
{
    if (Dtype != TOSA_REF_TYPE_BOOL)
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
        return 0;
    }

    this->fcn = [](InEigenType flag) -> OutEigenType { return !flag; };
    return 0;
}
// Constructs a NEGATE node: initialises the Negate attribute from attribute_
// and performs an initial registration of the element function.
template <int Rank, TOSA_REF_TYPE Dtype>
OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
    : UnaryNode<Rank, Dtype>(sgt_, Op_NEGATE, id_)
{
    INIT_ATTRIBUTE(Negate);

    // The return value is intentionally not inspected here, as in the original.
    this->register_fcn();
}
// Releases the attribute object held by this node.
template <int Rank, TOSA_REF_TYPE Dtype>
OpNegate<Rank, Dtype>::~OpNegate()
{
    // Deleting a null pointer is a no-op, so no explicit null check is needed.
    delete attribute;
}
// Installs the element function for NEGATE.
//
// Floating-point types negate the value directly (FP16/BF16/FP32 also pass the
// result through fpTrunc<Dtype>). Integer types negate in 64-bit arithmetic,
// check that the intermediate fits in int32 via REQUIRE, and then clamp to the
// output range. INT8 additionally applies the input and output zero points
// taken from the Negate attribute.
//
// Returns 0 after installing a function for a supported type. A non-zero zero
// point on a non-INT8 type, or an unsupported type, is reported via ERROR_IF.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
    // Zero points are only accepted for INT8; every other type must use 0.
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");

    switch (Dtype)
    {
        case TOSA_REF_TYPE_FP16:
        case TOSA_REF_TYPE_BF16:
        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a) -> OutEigenType {
                InEigenType result = -(a);
                return fpTrunc<Dtype>(result);
            };
            break;
        case TOSA_REF_TYPE_FP64:
            this->fcn = [](InEigenType a) -> OutEigenType {
                OutEigenType result = -(a);
                return result;
            };
            break;
        case TOSA_REF_TYPE_INT16:
        case TOSA_REF_TYPE_INT32:
            // `this` is captured because the REQUIRE macro is used inside the lambda.
            this->fcn = [this](InEigenType a) -> OutEigenType {
                // Widen to int64_t *before* negating. The previous `0L - a`
                // was only 64-bit where `long` is 64 bits wide; with a 32-bit
                // `long` the negation of INT32_MIN overflowed and the range
                // check below could never trigger.
                int64_t res_in_64     = -static_cast<int64_t>(a);
                int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
                int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
                REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64,
                        "OpNegate: result not in acc type range (int32)");

                // Clamp to the range of the output type.
                int64_t max_clip_in_64, min_clip_in_64;
                if (Dtype == TOSA_REF_TYPE_INT16)
                {
                    max_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max());
                    min_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min());
                }
                else
                {
                    max_clip_in_64 = i32_max_in_64;
                    min_clip_in_64 = i32_min_in_64;
                }
                return static_cast<InEigenType>(
                    std::min<int64_t>(max_clip_in_64, std::max<int64_t>(min_clip_in_64, res_in_64)));
            };
            break;
        case TOSA_REF_TYPE_INT8:
            this->fcn = [this](InEigenType a) -> OutEigenType {
                // -(a - input_zp), evaluated in 64 bits so the subtraction
                // cannot overflow a narrower intermediate type.
                int64_t res_in_64 =
                    static_cast<int64_t>(attribute->input1_zp()) - static_cast<int64_t>(a);
                int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
                int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
                REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64,
                        "OpNegate: result not in acc type range (int32)");

                // Re-apply the output zero point, then clamp to [QMin, QMax].
                res_in_64 += attribute->output_zp();
                InEigenType result = static_cast<InEigenType>(
                    std::min(std::max(res_in_64, static_cast<int64_t>(QMin)), static_cast<int64_t>(QMax)));
                return result;
            };
            break;
        default:
            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for RECIPROCAL (floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReciprocal<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(1.0 / x); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        // The numerator is a long double literal, as in the original.
        this->fcn = [](InEigenType x) -> OutEigenType { return (1.0L / x); };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for RSQRT (reciprocal square root,
// floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpRsqrt<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        // Single-precision square root, then passed through fpTrunc<Dtype>.
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        // The numerator is a long double literal, as in the original.
        this->fcn = [](InEigenType x) -> OutEigenType { return (1.0L / sqrt(x)); };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Installs the element function for SIN (floating-point types only).
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSin<Rank, Dtype>::register_fcn()
{
    if (Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 || Dtype == TOSA_REF_TYPE_FP32)
    {
        this->fcn = [](InEigenType x) -> OutEigenType { return fpTrunc<Dtype>(sin(x)); };
    }
    else if (Dtype == TOSA_REF_TYPE_FP64)
    {
        if (g_func_config.abs_mode)
        {
            // ABS_ERROR bounds return: the input is passed through unchanged.
            this->fcn = [](InEigenType x) -> OutEigenType { return x; };
        }
        else
        {
            this->fcn = [](InEigenType x) -> OutEigenType { return sin(x); };
        }
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Explicit template instantiations for the node classes defined above.
// Each line pairs a class template with one TOSA type.
// NOTE(review): the macro name suggests it expands to one instantiation per
// rank from 0 to 6 -- confirm against the macro definition.

// UnaryNode base class: every element type used by the operators below.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64);
// OpAbs: floating-point types and INT32.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);
// OpBitwiseNot: integer types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
// OpCeil: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
// OpClz: INT32 only.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
// OpCos: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64);
// OpExp: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);
// OpFloor: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);
// OpLog: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);
// OpLogicalNot: BOOL only.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
// OpNegate: floating-point and integer types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);
// OpRsqrt: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
// OpSin: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64);
// OpReciprocal: floating-point types.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);