IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8
* Add QAsymmS8 to QueueDescriptor supportedTypes
* Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes
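* Add comments; refactor TfLiteParser::ToTensorInfo to take an optional dimension mapping so the per-axis quantization dimension of depthwise convolution filters follows the [H, W, I, M] -> [M, I, H, W] weight permutation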
Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab
Signed-off-by: Keith Davis <keith.davis@arm.com>
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 560cdf1..593f3eb 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -301,7 +301,8 @@
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes,
+ const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
{
armnn::DataType type;
CHECK_TENSOR_PTR(tensorPtr);
@@ -317,10 +318,12 @@
case tflite::TensorType_INT8:
if (tensorPtr->quantization->zero_point.size() == 1 && tensorPtr->quantization->zero_point[0] != 0)
{
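+ // Exactly one zero point and it is non-zero: per-tensor asymmetric
type = armnn::DataType::QAsymmS8;
}
else
{
+ // Otherwise (zero point is 0, or there is not exactly one zero point): symmetric, per-tensor or per-channel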
type = armnn::DataType::QSymmS8;
}
break;
@@ -388,12 +391,13 @@
tensorPtr->quantization->scale.end(),
std::back_inserter(quantizationScales));
- // QSymm Per-axis
+ // QSymmS8 Per-axis
armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
safeShape.data(),
type,
quantizationScales,
- boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+ dimensionMappings[boost::numeric_cast<unsigned int>(
+ tensorPtr->quantization->quantized_dimension)]);
return result;
}
@@ -409,10 +413,11 @@
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, // shape is read from tensorPtr->shape
+ const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3}) // identity mapping unless the caller remaps axes
{
auto const & dimensions = AsUnsignedVector(tensorPtr->shape);
- return ToTensorInfo(tensorPtr, dimensions);
+ return ToTensorInfo(tensorPtr, dimensions, dimensionMappings); // mapping remaps the per-axis quantization dimension
}
template<typename T>
@@ -905,8 +910,11 @@
desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
+ // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
+ PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
+
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1], permutationVector);
// Assuming input is NHWC
unsigned int inputHeight = inputTensorInfo.GetShape()[1];
@@ -922,9 +930,6 @@
inputTensorInfo.GetShape()[3],
filterTensorInfo.GetShape()[3] / inputTensorInfo.GetShape()[3] });
- // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
- PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
-
CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
CalcPadding(inputWidth, filterWidth, desc.m_StrideX,