blob: d7bddc0003c86c30ce3df82e33d7f15749c5fe94 [file] [log] [blame]
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ewise_unary.h"
#include "quant_util.h"
#include "template_types.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
template <int Rank, DType Dtype>
UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
    : GraphNode(op_, id_)
{
    // Declare the operand-count requirement (1, 1) and the accepted
    // rank range [0, 6]; both are validated in checkTensorAttributes().
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);

    // Placeholder element function: ignores its argument and yields a
    // value-initialised output element. Concrete ops overwrite this->fcn
    // in their register_fcn().
    fcn = [](InEigenType) -> OutEigenType { return OutEigenType(); };
}
// Nothing to release explicitly: the typed tensor pointers are views onto
// tensors the node does not delete here.
template <int Rank, DType Dtype>
UnaryNode<Rank, Dtype>::~UnaryNode() = default;
// Validates the node's single input/output pair and binds the typed tensor
// views used by eval().
// Returns 0 when all checks pass, 1 when a validation check fails.
template <int Rank, DType Dtype>
int UnaryNode<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // Input and output must agree in rank and size; as used here,
    // matchRankSize() reports a mismatch with a nonzero result.
    if (inputs[0]->matchRankSize(*outputs[0]))
    {
        printNodeValidationError("UnaryNode: input and output rank and size must match");
        return 1;
    }

    // Element types are enforced by the casts: a tensor that is not a
    // TensorTemplate of the expected type yields nullptr, which ASSERT_MEM
    // rejects.
    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(a && result);

    return 0;
}
// Applies the registered per-element function to the input tensor and
// writes the result into the output tensor, then returns the status of
// GraphNode::eval().
template <int Rank, DType Dtype>
int UnaryNode<Rank, Dtype>::eval()
{
    const auto& input_tensor = this->a->getTensor();
    auto&& output_tensor     = this->result->getTensor();

    output_tensor = input_tensor.unaryExpr(this->fcn);

    return GraphNode::eval();
}
// Installs the absolute-value element function; supported for FLOAT and
// INT32. Any other DType is reported through FATAL_ERROR_NODE.
template <int Rank, DType Dtype>
int OpAbs<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT || Dtype == DType_INT32)
    {
        this->fcn = [](InEigenType value) -> OutEigenType {
            const InEigenType zero = static_cast<InEigenType>(0);
            if (value > zero)
            {
                return value;
            }
            // Every value that is not strictly positive is negated.
            return -value;
        };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the bitwise-complement element function for the integer types
// AINT8, INT16 and INT32. Any other DType is reported via FATAL_ERROR_NODE.
template <int Rank, DType Dtype>
int OpBitwiseNot<Rank, Dtype>::register_fcn()
{
    const bool is_integer_type = (Dtype == DType_AINT8) || (Dtype == DType_INT16) || (Dtype == DType_INT32);

    if (is_integer_type)
    {
        this->fcn = [](InEigenType value) -> OutEigenType { return ~value; };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the round-toward-+infinity element function (FLOAT only).
template <int Rank, DType Dtype>
int OpCeil<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT)
    {
        // std::ceil's float overload matches ceilf exactly.
        this->fcn = [](InEigenType value) -> OutEigenType { return std::ceil(value); };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the count-leading-zeros element function (INT32 only).
// The lambda scans from the most significant bit of a num_bits-wide value
// and returns how many zero bits precede the first set bit (num_bits for 0).
template <int Rank, DType Dtype>
int OpClz<Rank, Dtype>::register_fcn()
{
    // Initialised so the lambda below never captures an indeterminate value,
    // even on the unsupported-type path.
    int32_t num_bits = 0;
    switch (Dtype)
    {
        case DType_INT32:
            num_bits = 32;
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }

    this->fcn = [num_bits](int32_t a) -> int32_t {
        // Inspect the unsigned bit pattern: right-shifting a negative signed
        // value is implementation-defined before C++20.
        const uint32_t bits   = static_cast<uint32_t>(a);
        int32_t leading_zeros = 0;
        for (int bit = num_bits - 1; bit >= 0; bit--)
        {
            if (((bits >> bit) & 0x1u) != 0)
            {
                break;
            }
            leading_zeros++;
        }
        return leading_zeros;
    };
    return 0;
}
// Installs the natural-exponential element function (FLOAT only).
template <int Rank, DType Dtype>
int OpExp<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT)
    {
        this->fcn = [](InEigenType value) -> OutEigenType {
            const auto exponential = expf(value);
            return exponential;
        };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the round-toward-−infinity element function (FLOAT only).
template <int Rank, DType Dtype>
int OpFloor<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT)
    {
        // std::floor's float overload matches floorf exactly.
        this->fcn = [](InEigenType value) -> OutEigenType { return std::floor(value); };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the natural-logarithm element function (FLOAT only).
template <int Rank, DType Dtype>
int OpLog<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT)
    {
        this->fcn = [](InEigenType value) -> OutEigenType {
            const auto logarithm = logf(value);
            return logarithm;
        };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the logical-negation element function (BOOL only).
template <int Rank, DType Dtype>
int OpLogicalNot<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_BOOL)
    {
        this->fcn = [](InEigenType value) -> OutEigenType {
            const auto negated = !value;
            return negated;
        };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the negation element function.
//  - FLOAT / INT16 / INT32: plain arithmetic negation.
//  - AINT8: zero-point aware negation, -(a - input_zp) + output_zp,
//    saturated to the signed 8-bit range (as in the TOSA NEGATE pseudocode).
// Any other DType is reported through FATAL_ERROR_NODE.
template <int Rank, DType Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
    switch (Dtype)
    {
        case DType_FLOAT:
        case DType_INT16:
        case DType_INT32:
            // NOTE(review): negating the most negative INT32 value overflows;
            // presumably the spec forbids that input — confirm.
            this->fcn = [](InEigenType a) -> OutEigenType {
                InEigenType result = -(a);
                return result;
            };
            break;
        case DType_AINT8:
            ASSERT(this->qinfo);
            this->fcn = [this](InEigenType a) -> OutEigenType {
                // Widen to 64 bits so removing/adding the zero points cannot
                // overflow, then clip into the int8 range the output must hold
                // (e.g. -(-128 - 0) + 0 == 128 previously escaped the range).
                int64_t acc = -(static_cast<int64_t>(a) - static_cast<int64_t>(this->qinfo->input_zp())) +
                              static_cast<int64_t>(this->qinfo->output_zp());
                acc = std::max<int64_t>(acc, std::numeric_limits<int8_t>::min());
                acc = std::min<int64_t>(acc, std::numeric_limits<int8_t>::max());
                return static_cast<OutEigenType>(acc);
            };
            break;
        default:
            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the reciprocal (1 / x) element function (FLOAT only).
template <int Rank, DType Dtype>
int OpReciprocal<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT)
    {
        this->fcn = [](InEigenType value) -> OutEigenType {
            // Division is carried out in double, then narrowed to the output type.
            const auto inverse = 1.0 / value;
            return inverse;
        };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Installs the reciprocal-square-root (1 / sqrt(x)) element function
// (FLOAT only).
template <int Rank, DType Dtype>
int OpRsqrt<Rank, Dtype>::register_fcn()
{
    if (Dtype == DType_FLOAT)
    {
        this->fcn = [](InEigenType value) -> OutEigenType {
            const auto root = sqrtf(value);
            // As before, the division happens in double and is then narrowed.
            return 1.0 / root;
        };
    }
    else
    {
        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
    }
    return 0;
}
// Explicit template instantiations: one line per (operator, DType) pair that
// the corresponding register_fcn() above accepts. Judging by its name, the
// macro instantiates each pair for ranks 0 through 6 (matching the
// setRequiredRank(0, 6) limit) — confirm against template_types.h.
// Adding a DType here without a matching case in register_fcn() would only
// trigger FATAL_ERROR_NODE at runtime.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);