blob: 970c4151175bea998f338fe3a4f2a751da48e45e [file] [log] [blame]
//
// Copyright © 2017,2019-2021,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "SerializerUtils.hpp"
namespace armnnSerializer
{
armnnSerializer::ComparisonOperation GetFlatBufferComparisonOperation(armnn::ComparisonOperation comparisonOperation)
{
switch (comparisonOperation)
{
case armnn::ComparisonOperation::Equal:
return armnnSerializer::ComparisonOperation::ComparisonOperation_Equal;
case armnn::ComparisonOperation::Greater:
return armnnSerializer::ComparisonOperation::ComparisonOperation_Greater;
case armnn::ComparisonOperation::GreaterOrEqual:
return armnnSerializer::ComparisonOperation::ComparisonOperation_GreaterOrEqual;
case armnn::ComparisonOperation::Less:
return armnnSerializer::ComparisonOperation::ComparisonOperation_Less;
case armnn::ComparisonOperation::LessOrEqual:
return armnnSerializer::ComparisonOperation::ComparisonOperation_LessOrEqual;
case armnn::ComparisonOperation::NotEqual:
default:
return armnnSerializer::ComparisonOperation::ComparisonOperation_NotEqual;
}
}
armnnSerializer::LogicalBinaryOperation GetFlatBufferLogicalBinaryOperation(
armnn::LogicalBinaryOperation logicalBinaryOperation)
{
switch (logicalBinaryOperation)
{
case armnn::LogicalBinaryOperation::LogicalAnd:
return armnnSerializer::LogicalBinaryOperation::LogicalBinaryOperation_LogicalAnd;
case armnn::LogicalBinaryOperation::LogicalOr:
return armnnSerializer::LogicalBinaryOperation::LogicalBinaryOperation_LogicalOr;
default:
throw armnn::InvalidArgumentException("Logical Binary operation unknown");
}
}
armnnSerializer::ConstTensorData GetFlatBufferConstTensorData(armnn::DataType dataType)
{
switch (dataType)
{
case armnn::DataType::Float32:
case armnn::DataType::Signed32:
return armnnSerializer::ConstTensorData::ConstTensorData_IntData;
case armnn::DataType::Float16:
case armnn::DataType::QSymmS16:
return armnnSerializer::ConstTensorData::ConstTensorData_ShortData;
case armnn::DataType::QAsymmS8:
case armnn::DataType::QAsymmU8:
case armnn::DataType::QSymmS8:
case armnn::DataType::Boolean:
return armnnSerializer::ConstTensorData::ConstTensorData_ByteData;
case armnn::DataType::Signed64:
return armnnSerializer::ConstTensorData::ConstTensorData_LongData;
default:
return armnnSerializer::ConstTensorData::ConstTensorData_NONE;
}
}
armnnSerializer::DataType GetFlatBufferDataType(armnn::DataType dataType)
{
switch (dataType)
{
case armnn::DataType::Float32:
return armnnSerializer::DataType::DataType_Float32;
case armnn::DataType::Float16:
return armnnSerializer::DataType::DataType_Float16;
case armnn::DataType::Signed32:
return armnnSerializer::DataType::DataType_Signed32;
case armnn::DataType::Signed64:
return armnnSerializer::DataType::DataType_Signed64;
case armnn::DataType::QSymmS16:
return armnnSerializer::DataType::DataType_QSymmS16;
case armnn::DataType::QAsymmS8:
return armnnSerializer::DataType::DataType_QAsymmS8;
case armnn::DataType::QAsymmU8:
return armnnSerializer::DataType::DataType_QAsymmU8;
case armnn::DataType::QSymmS8:
return armnnSerializer::DataType::DataType_QSymmS8;
case armnn::DataType::Boolean:
return armnnSerializer::DataType::DataType_Boolean;
default:
return armnnSerializer::DataType::DataType_Float16;
}
}
armnnSerializer::DataLayout GetFlatBufferDataLayout(armnn::DataLayout dataLayout)
{
switch (dataLayout)
{
case armnn::DataLayout::NHWC:
return armnnSerializer::DataLayout::DataLayout_NHWC;
case armnn::DataLayout::NDHWC:
return armnnSerializer::DataLayout::DataLayout_NDHWC;
case armnn::DataLayout::NCDHW:
return armnnSerializer::DataLayout::DataLayout_NCDHW;
case armnn::DataLayout::NCHW:
default:
return armnnSerializer::DataLayout::DataLayout_NCHW;
}
}
armnnSerializer::BinaryOperation GetFlatBufferBinaryOperation(armnn::BinaryOperation binaryOperation)
{
switch (binaryOperation)
{
case armnn::BinaryOperation::Add:
return armnnSerializer::BinaryOperation::BinaryOperation_Add;
case armnn::BinaryOperation::Div:
return armnnSerializer::BinaryOperation::BinaryOperation_Div;
case armnn::BinaryOperation::Maximum:
return armnnSerializer::BinaryOperation::BinaryOperation_Maximum;
case armnn::BinaryOperation::Minimum:
return armnnSerializer::BinaryOperation::BinaryOperation_Minimum;
case armnn::BinaryOperation::Mul:
return armnnSerializer::BinaryOperation::BinaryOperation_Mul;
case armnn::BinaryOperation::Sub:
return armnnSerializer::BinaryOperation::BinaryOperation_Sub;
case armnn::BinaryOperation::SqDiff:
return armnnSerializer::BinaryOperation::BinaryOperation_SqDiff;
case armnn::BinaryOperation::Power:
return armnnSerializer::BinaryOperation::BinaryOperation_Power;
default:
throw armnn::InvalidArgumentException("Elementwise Binary operation unknown");
}
}
armnnSerializer::UnaryOperation GetFlatBufferUnaryOperation(armnn::UnaryOperation unaryOperation)
{
switch (unaryOperation)
{
case armnn::UnaryOperation::Abs:
return armnnSerializer::UnaryOperation::UnaryOperation_Abs;
case armnn::UnaryOperation::Ceil:
return armnnSerializer::UnaryOperation::UnaryOperation_Ceil;
case armnn::UnaryOperation::Rsqrt:
return armnnSerializer::UnaryOperation::UnaryOperation_Rsqrt;
case armnn::UnaryOperation::Sqrt:
return armnnSerializer::UnaryOperation::UnaryOperation_Sqrt;
case armnn::UnaryOperation::Exp:
return armnnSerializer::UnaryOperation::UnaryOperation_Exp;
case armnn::UnaryOperation::Neg:
return armnnSerializer::UnaryOperation::UnaryOperation_Neg;
case armnn::UnaryOperation::LogicalNot:
return armnnSerializer::UnaryOperation::UnaryOperation_LogicalNot;
case armnn::UnaryOperation::Log:
return armnnSerializer::UnaryOperation::UnaryOperation_Log;
case armnn::UnaryOperation::Sin:
return armnnSerializer::UnaryOperation::UnaryOperation_Sin;
default:
throw armnn::InvalidArgumentException("Elementwise Unary operation unknown");
}
}
armnnSerializer::PoolingAlgorithm GetFlatBufferPoolingAlgorithm(armnn::PoolingAlgorithm poolingAlgorithm)
{
switch (poolingAlgorithm)
{
case armnn::PoolingAlgorithm::Average:
return armnnSerializer::PoolingAlgorithm::PoolingAlgorithm_Average;
case armnn::PoolingAlgorithm::L2:
return armnnSerializer::PoolingAlgorithm::PoolingAlgorithm_L2;
case armnn::PoolingAlgorithm::Max:
default:
return armnnSerializer::PoolingAlgorithm::PoolingAlgorithm_Max;
}
}
armnnSerializer::OutputShapeRounding GetFlatBufferOutputShapeRounding(armnn::OutputShapeRounding outputShapeRounding)
{
switch (outputShapeRounding)
{
case armnn::OutputShapeRounding::Ceiling:
return armnnSerializer::OutputShapeRounding::OutputShapeRounding_Ceiling;
case armnn::OutputShapeRounding::Floor:
default:
return armnnSerializer::OutputShapeRounding::OutputShapeRounding_Floor;
}
}
armnnSerializer::PaddingMethod GetFlatBufferPaddingMethod(armnn::PaddingMethod paddingMethod)
{
switch (paddingMethod)
{
case armnn::PaddingMethod::IgnoreValue:
return armnnSerializer::PaddingMethod::PaddingMethod_IgnoreValue;
case armnn::PaddingMethod::Exclude:
default:
return armnnSerializer::PaddingMethod::PaddingMethod_Exclude;
}
}
armnnSerializer::PaddingMode GetFlatBufferPaddingMode(armnn::PaddingMode paddingMode)
{
switch (paddingMode)
{
case armnn::PaddingMode::Reflect:
return armnnSerializer::PaddingMode::PaddingMode_Reflect;
case armnn::PaddingMode::Symmetric:
return armnnSerializer::PaddingMode::PaddingMode_Symmetric;
default:
return armnnSerializer::PaddingMode::PaddingMode_Constant;
}
}
armnnSerializer::NormalizationAlgorithmChannel GetFlatBufferNormalizationAlgorithmChannel(
armnn::NormalizationAlgorithmChannel normalizationAlgorithmChannel)
{
switch (normalizationAlgorithmChannel)
{
case armnn::NormalizationAlgorithmChannel::Across:
return armnnSerializer::NormalizationAlgorithmChannel::NormalizationAlgorithmChannel_Across;
case armnn::NormalizationAlgorithmChannel::Within:
return armnnSerializer::NormalizationAlgorithmChannel::NormalizationAlgorithmChannel_Within;
default:
return armnnSerializer::NormalizationAlgorithmChannel::NormalizationAlgorithmChannel_Across;
}
}
armnnSerializer::NormalizationAlgorithmMethod GetFlatBufferNormalizationAlgorithmMethod(
armnn::NormalizationAlgorithmMethod normalizationAlgorithmMethod)
{
switch (normalizationAlgorithmMethod)
{
case armnn::NormalizationAlgorithmMethod::LocalBrightness:
return armnnSerializer::NormalizationAlgorithmMethod::NormalizationAlgorithmMethod_LocalBrightness;
case armnn::NormalizationAlgorithmMethod::LocalContrast:
return armnnSerializer::NormalizationAlgorithmMethod::NormalizationAlgorithmMethod_LocalContrast;
default:
return armnnSerializer::NormalizationAlgorithmMethod::NormalizationAlgorithmMethod_LocalBrightness;
}
}
armnnSerializer::ResizeMethod GetFlatBufferResizeMethod(armnn::ResizeMethod method)
{
switch (method)
{
case armnn::ResizeMethod::NearestNeighbor:
return armnnSerializer::ResizeMethod_NearestNeighbor;
case armnn::ResizeMethod::Bilinear:
return armnnSerializer::ResizeMethod_Bilinear;
default:
return armnnSerializer::ResizeMethod_NearestNeighbor;
}
}
armnnSerializer::ReduceOperation GetFlatBufferReduceOperation(armnn::ReduceOperation reduceOperation)
{
switch (reduceOperation)
{
case armnn::ReduceOperation::Sum:
return armnnSerializer::ReduceOperation::ReduceOperation_Sum;
case armnn::ReduceOperation::Max:
return armnnSerializer::ReduceOperation::ReduceOperation_Max;
case armnn::ReduceOperation::Mean:
return armnnSerializer::ReduceOperation::ReduceOperation_Mean;
case armnn::ReduceOperation::Min:
return armnnSerializer::ReduceOperation::ReduceOperation_Min;
case armnn::ReduceOperation::Prod:
return armnnSerializer::ReduceOperation::ReduceOperation_Prod;
default:
return armnnSerializer::ReduceOperation::ReduceOperation_Sum;
}
}
} // namespace armnnSerializer