IVGCVSW-4139 Fix regression in ConvertDequantize()
* Removed TENSOR_QUANT8_SYMM from the list of generally supported
tensor data types
* Fixed the tensor info in DequantizeIfRequired() for QSymm8 weights
that are dequantized on the fly
* Moved the code that checks whether a Dequantize operator feeds
FullyConnected or Lstm weights out of ConvertDequantize() and into
a separate helper, IsQSymmDequantizeForWeights(), in 1.2/HalPolicy.cpp
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I19ea6f89a90f553a964b87d44f8ad8a064e96f7f
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index f901a31..c8e2968 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -23,6 +23,63 @@
using namespace armnn;
+namespace
+{
+
+// Returns true if the given DEQUANTIZE operation takes a constant QSymm8 operand
+// as input 0 and its output 0 is consumed either as the weights (input slot 1) of a
+// FULLY_CONNECTED operation or as any input of an LSTM operation. QSymm8 weights are
+// dequantized on the fly by the driver, so such a DEQUANTIZE needs no conversion.
+// Returns false for anything else, including operations that are missing their
+// input or output operand.
+bool IsQSymmDequantizeForWeights(const Operation& operation, const Model& model)
+{
+    const Operand* operand = GetInputOperand<hal_1_2::HalPolicy>(operation, 0, model);
+    if (!operand)
+    {
+        return false;
+    }
+
+    if (!IsQSymm8(*operand))
+    {
+        // Only QSymm8 weights are dequantized on the fly by the driver
+        return false;
+    }
+
+    if (!IsOperandConstant<hal_1_2::HalPolicy>(*operand))
+    {
+        // Non-const input is not accepted for weights
+        return false;
+    }
+
+    // Guard against a malformed operation before reading its output operand index
+    if (operation.outputs.size() == 0)
+    {
+        return false;
+    }
+
+    // Scan all operations for one that consumes the Dequantize output as weights.
+    // Operand indices are uint32_t, so keep the same type for the comparison.
+    const uint32_t outputIndex = operation.outputs[0];
+    for (const auto& consumer : model.operations)
+    {
+        switch (consumer.type)
+        {
+            case hal_1_2::HalPolicy::OperationType::FULLY_CONNECTED:
+                // Weights are bound to input slot 1; check the size first so a
+                // malformed operation cannot trigger an out-of-bounds read
+                if (consumer.inputs.size() > 1 && consumer.inputs[1] == outputIndex)
+                {
+                    return true;
+                }
+                break;
+            case hal_1_2::HalPolicy::OperationType::LSTM:
+                // NOTE(review): every LSTM input slot is matched here, not only the
+                // weight slots; confirm this is intended
+                for (const uint32_t inputIndex : consumer.inputs)
+                {
+                    if (inputIndex == outputIndex)
+                    {
+                        return true;
+                    }
+                }
+                break;
+            default:
+                break;
+        }
+    }
+
+    return false;
+}
+
+} // anonymous namespace
+
bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
{
switch (operation.type)
@@ -561,6 +618,14 @@
+// Converts an NNAPI DEQUANTIZE operation for the 1.2 HAL.
+// If IsQSymmDequantizeForWeights() reports that the input is a constant QSymm8
+// operand feeding FULLY_CONNECTED weights or an LSTM input, the operation is
+// accepted (returns true) without touching `data`; otherwise the result of the
+// shared ::ConvertDequantize<hal_1_2::HalPolicy>() implementation is returned.
bool HalPolicy::ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_2::HalPolicy::ConvertDequantize()");
+
+ if (IsQSymmDequantizeForWeights(operation, model))
+ {
+ // NOTE: QSymm8 weights are dequantized internally by the driver,
+ // therefore this type of Dequantize is implicitly supported
+ // (no layer is added here; presumably the consuming operation's conversion
+ // handles the QSymm8 constant, e.g. via DequantizeIfRequired() — confirm)
+ return true;
+ }
+
+ // Every other Dequantize goes through the shared conversion path
return ::ConvertDequantize<hal_1_2::HalPolicy>(operation, model, data);
}