# blob: b57723ba6dc2c3cb8c699bc67eb9956c00cd9161 [file] [log] [blame]
# Copyright © 2020,2023 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import inspect
import sys
import logging
from ._generated.pyarmnn_version import GetVersion, GetMajorVersion, GetMinorVersion
# Deprecation notice: emitted with print() (so it goes to stdout, not through
# the logging or warnings machinery) when this module's top-level code runs.
print("WARNING: PyArmNN is deprecated. To use ArmNN on python, please use the tflite delegate. Expected to be removed in release 24.05.")
# Parsers
try:
    from ._generated.pyarmnn_onnxparser import IOnnxParser
except ImportError as err:
    # The Onnx parser bindings are optional: log the problem and install a
    # placeholder instead of failing the whole package import.
    logger = logging.getLogger(__name__)
    message = "Your ArmNN library instance does not support Onnx models parser functionality. "
    logger.warning("%s Skipped IOnnxParser import.", message)
    logger.debug(str(err))

    # Keep a parser-specific copy of the text. The placeholder below looks the
    # name up when it is *called*, and the other optional-import fallbacks in
    # this file rebind the shared module-level `message`, which would make this
    # placeholder report the wrong parser.
    _onnx_parser_message = message

    def IOnnxParser():
        """Placeholder used when the Onnx parser bindings could not be imported.

        Raises:
            RuntimeError: always, carrying the Onnx-specific explanation.
        """
        raise RuntimeError(_onnx_parser_message)
try:
    from ._generated.pyarmnn_tfliteparser import ITfLiteParser, TfLiteParserOptions
except ImportError as err:
    # The TF Lite parser bindings are optional: log the problem and install a
    # placeholder instead of failing the whole package import.
    # NOTE(review): only ITfLiteParser gets a placeholder here;
    # TfLiteParserOptions stays undefined when this import fails.
    logger = logging.getLogger(__name__)
    message = "Your ArmNN library instance does not support TF lite models parser functionality. "
    logger.warning("%s Skipped ITfLiteParser import.", message)
    logger.debug(str(err))

    # Keep a parser-specific copy of the text. The placeholder below looks the
    # name up when it is *called*, and the other optional-import fallbacks in
    # this file rebind the shared module-level `message`, which would make this
    # placeholder report the wrong parser.
    _tflite_parser_message = message

    def ITfLiteParser():
        """Placeholder used when the TF Lite parser bindings could not be imported.

        Raises:
            RuntimeError: always, carrying the TF Lite-specific explanation.
        """
        raise RuntimeError(_tflite_parser_message)
try:
    from ._generated.pyarmnn_deserializer import IDeserializer
except ImportError as err:
    # The .armnn deserializer bindings are optional: log the problem and
    # install a placeholder instead of failing the whole package import.
    logger = logging.getLogger(__name__)
    message = "Your ArmNN library instance does not have ArmNN model (.armnn) parser functionality. "
    logger.warning("%s Skipped IDeserializer import.", message)
    logger.debug(str(err))

    # Keep a deserializer-specific copy of the text. The placeholder below
    # looks the name up when it is *called*, so any later rebinding of the
    # shared module-level `message` must not be able to change what it reports.
    _deserializer_message = message

    def IDeserializer():
        """Placeholder used when the deserializer bindings could not be imported.

        Raises:
            RuntimeError: always, carrying the deserializer-specific explanation.
        """
        raise RuntimeError(_deserializer_message)
# Network
from ._generated.pyarmnn import Optimize, OptimizerOptions, IOptimizedNetwork, IInputSlot, \
IOutputSlot, IConnectableLayer, INetwork
# Backend
from ._generated.pyarmnn import BackendId
from ._generated.pyarmnn import IDeviceSpec
from ._generated.pyarmnn import BackendOptions, BackendOption
# Tensors
from ._generated.pyarmnn import TensorInfo, TensorShape
# Runtime
from ._generated.pyarmnn import IRuntime, CreationOptions, INetworkProperties
# Profiler
from ._generated.pyarmnn import IProfiler
# Types
from ._generated.pyarmnn import DataType_Float16, DataType_Float32, DataType_QAsymmU8, DataType_Signed32, \
DataType_Boolean, DataType_QSymmS16, DataType_QSymmS8, DataType_QAsymmS8, ShapeInferenceMethod_ValidateOnly, \
ShapeInferenceMethod_InferAndValidate
from ._generated.pyarmnn import DataLayout_NCHW, DataLayout_NHWC, DataLayout_NCDHW, DataLayout_NDHWC
from ._generated.pyarmnn import MemorySource_Malloc, MemorySource_Undefined, MemorySource_DmaBuf, \
MemorySource_DmaBufProtected
from ._generated.pyarmnn import ProfilingDetailsMethod_Undefined, ProfilingDetailsMethod_DetailsWithEvents, \
ProfilingDetailsMethod_DetailsOnly
from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_BoundedReLu, ActivationFunction_LeakyReLu, \
ActivationFunction_Linear, ActivationFunction_ReLu, ActivationFunction_Sigmoid, ActivationFunction_SoftReLu, \
ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor
from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor
from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor
from ._generated.pyarmnn import ChannelShuffleDescriptor, ComparisonDescriptor, ComparisonOperation_Equal, \
ComparisonOperation_Greater, ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual
from ._generated.pyarmnn import UnaryOperation_Abs, UnaryOperation_Exp, UnaryOperation_Sqrt, UnaryOperation_Rsqrt, \
UnaryOperation_Neg, ElementwiseUnaryDescriptor
from ._generated.pyarmnn import LogicalBinaryOperation_LogicalAnd, LogicalBinaryOperation_LogicalOr, \
LogicalBinaryDescriptor
from ._generated.pyarmnn import Convolution2dDescriptor, Convolution3dDescriptor, DepthToSpaceDescriptor, \
DepthwiseConvolution2dDescriptor, DetectionPostProcessDescriptor, FakeQuantizationDescriptor, FillDescriptor, \
FullyConnectedDescriptor, GatherDescriptor, InstanceNormalizationDescriptor, LstmDescriptor, \
L2NormalizationDescriptor, MeanDescriptor
from ._generated.pyarmnn import NormalizationAlgorithmChannel_Across, NormalizationAlgorithmChannel_Within, \
NormalizationAlgorithmMethod_LocalBrightness, NormalizationAlgorithmMethod_LocalContrast, NormalizationDescriptor
from ._generated.pyarmnn import PaddingMode_Constant, PaddingMode_Reflect, PaddingMode_Symmetric, PadDescriptor
from ._generated.pyarmnn import PermutationVector, PermuteDescriptor
from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding_Floor, \
PaddingMethod_Exclude, PaddingMethod_IgnoreValue, PoolingAlgorithm_Average, PoolingAlgorithm_L2, \
PoolingAlgorithm_Max, Pooling2dDescriptor, Pooling3dDescriptor
from ._generated.pyarmnn import ReduceDescriptor, ReduceOperation_Prod, ReduceOperation_Max, ReduceOperation_Mean, \
ReduceOperation_Min, ReduceOperation_Sum
from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \
ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \
StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \
TransposeDescriptor, SplitterDescriptor
from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation
from ._generated.pyarmnn import LstmInputParams, QuantizedLstmInputParams
# Public API
# Quantization
from ._quantization.quantize_and_dequantize import quantize, dequantize
# Tensor
from ._tensor.tensor import Tensor
from ._tensor.const_tensor import ConstTensor
from ._tensor.workload_tensors import make_input_tensors, make_output_tensors, workload_tensors_to_ndarray
# Utilities
from ._utilities.profiling_helper import ProfilerData, get_profiling_data
from ._version import __version__, __arm_ml_version__
# Whatever GetVersion() from the generated pyarmnn_version module reports.
ARMNN_VERSION = GetVersion()


def __check_version():
    """Pass ARMNN_VERSION to ``check_armnn_version`` from the ``_version`` module.

    The helper is imported inside the function, so its name is not bound at
    module level.
    """
    from ._version import check_armnn_version as verify_armnn_version

    verify_armnn_version(ARMNN_VERSION)


# Run the check once, as part of importing the package.
__check_version()
# Names defined in this module that must stay out of the exported API.
__private_api_names = ['__check_version']

# Export every class and Python function currently bound in this module's
# namespace, minus the private helpers listed above. A comprehension is used
# instead of a module-level ``for`` loop so that the iteration variables do not
# end up as stray attributes of the package.
__all__ = [
    member_name
    for member_name, member in inspect.getmembers(sys.modules[__name__])
    if (inspect.isclass(member) or inspect.isfunction(member))
    and member_name not in __private_api_names
]