// blob: 8a084c7e848db9b2be977a05ae61b13feb384d99 [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "comparison.h"
#include "arith_util.h"
#include "quant_util.h"
#include "template_types.h"
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
// Stores the element-wise equality predicate (a == b) in this->fcn.
//
// The node's Rank is first checked against the MAX_RANK of the configured
// TOSA level. FP16, BF16, FP32, INT32 and FP64 are the accepted element
// types; every other Dtype is reported through ERROR_IF.
//
// Returns 0 when control reaches the end of the function.
// NOTE(review): LEVEL_CHECK and ERROR_IF are macros defined outside this file;
// they presumably bail out with a non-zero status on failure -- confirm.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpEqual<Rank, Dtype>::register_fcn()
{
    // Dtype is a template argument, so the set of accepted types can be
    // evaluated once, up front.
    constexpr bool dtype_supported = Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 ||
                                     Dtype == TOSA_REF_TYPE_FP32 || Dtype == TOSA_REF_TYPE_INT32 ||
                                     Dtype == TOSA_REF_TYPE_FP64;

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (dtype_supported)
    {
        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs == rhs; };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Stores the element-wise strict "greater than" predicate (a > b) in
// this->fcn.
//
// The node's Rank is first checked against the MAX_RANK of the configured
// TOSA level. FP16, BF16, FP32, INT32 and FP64 are the accepted element
// types; every other Dtype is reported through ERROR_IF.
//
// Returns 0 when control reaches the end of the function.
// NOTE(review): LEVEL_CHECK and ERROR_IF are macros defined outside this file;
// they presumably bail out with a non-zero status on failure -- confirm.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpGreater<Rank, Dtype>::register_fcn()
{
    // Dtype is a template argument, so the set of accepted types can be
    // evaluated once, up front.
    constexpr bool dtype_supported = Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 ||
                                     Dtype == TOSA_REF_TYPE_FP32 || Dtype == TOSA_REF_TYPE_INT32 ||
                                     Dtype == TOSA_REF_TYPE_FP64;

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (dtype_supported)
    {
        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs > rhs; };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Stores the element-wise "greater than or equal" predicate (a >= b) in
// this->fcn.
//
// The node's Rank is first checked against the MAX_RANK of the configured
// TOSA level. FP16, BF16, FP32, INT32 and FP64 are the accepted element
// types; every other Dtype is reported through ERROR_IF.
//
// Returns 0 when control reaches the end of the function.
// NOTE(review): LEVEL_CHECK and ERROR_IF are macros defined outside this file;
// they presumably bail out with a non-zero status on failure -- confirm.
template <int Rank, TOSA_REF_TYPE Dtype>
int OpGreaterEqual<Rank, Dtype>::register_fcn()
{
    // Dtype is a template argument, so the set of accepted types can be
    // evaluated once, up front.
    constexpr bool dtype_supported = Dtype == TOSA_REF_TYPE_FP16 || Dtype == TOSA_REF_TYPE_BF16 ||
                                     Dtype == TOSA_REF_TYPE_FP32 || Dtype == TOSA_REF_TYPE_INT32 ||
                                     Dtype == TOSA_REF_TYPE_FP64;

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (dtype_supported)
    {
        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs >= rhs; };
    }
    else
    {
        ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
    }

    return 0;
}
// Explicit template instantiations.
//
// Each comparison operator (OpEqual, OpGreater, OpGreaterEqual) is
// instantiated for FP16, BF16, FP32, INT32 and FP64 -- the same element types
// that the register_fcn() implementations in this file accept.
// NOTE(review): DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE is defined outside
// this file; judging by its name it presumably emits one instantiation per
// tensor rank 0 through 6 -- confirm against the macro definition.

// OpEqual
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
// OpGreater
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
// OpGreaterEqual
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);