IVGCVSW-6119 ConstTensorsAsInput: FullyConnected
* Constant weights and biases are now stored as Constant layers.
* Updated Serializer, Deserializer and unit tests to reflect this.
* Updated TfLiteDelegate, TfLiteParser and OnnxParser.
* Updated Schema with IsConstant and ConstantTensorsAsInputs.
* Updated Ref backend to handle constant weights and
biases as inputs rather than reading them from member variables.
* Added EndToEnd tests for both dynamic and constant inputs.
!android-nn-driver:5959
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: Ibf3cf437df1100e4b322b0d303c575c6339f9696
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 1c18551..3f5972d 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -36,7 +36,11 @@
return info;
}
- return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
+ return TensorInfo(info.GetShape(),
+ type.value(),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset(),
+ info.IsConstant());
}
} // anonymous namespace
@@ -364,16 +368,7 @@
TensorInfo weightsInfo;
const TensorInfo* weightsInfoPtr = nullptr;
- if (descriptor.m_ConstantWeights)
- {
- ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
- weightsInfo = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
- }
- else
- {
- weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
-
- }
+ weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
weightsInfoPtr = &weightsInfo;
TensorInfo biasInfo;
@@ -385,17 +380,8 @@
if (descriptor.m_BiasEnabled)
{
- if(descriptor.m_ConstantWeights)
- {
- ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
- biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
- biasInfoPtr = &biasInfo;
- }
- else
- {
- biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
- biasInfoPtr = &biasInfo;
- }
+ biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
+ biasInfoPtr = &biasInfo;
}
else
{