IVGCVSW-5826 Change weights layout for depthwise to [1,H,W,I*M]
* This change is necessary because TfLite uses a [1,H,W,I*M] format
and uses the I*M dimension for per-axis quantization. Our previous
layout [M,I,H,W] can't handle the corresponding quantization scales.
* Updates the OnnxParser, TfLiteParser and TfLiteDelegate
* Updates the CpuRef, CpuAcc and GpuAcc backends
* Adjusts unit tests
* Adds a test to ensure models with the old layout can still be read
and executed
* Adds a conversion function from the new layout back to the previous
layout ([1,H,W,I*M] --> [M,I,H,W]), which can be used by backend
developers
!android-nn-driver:5553
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: Ifef23368b8c3702cf315a5838d214f7dc13c0152
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index a409715..1c9a1de 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -979,6 +979,7 @@
table FeatureCompatibilityVersions {
bindingIdsScheme:uint = 0;
+ weightsLayoutScheme:uint = 0;
}
// Root type for serialized data is the graph of the network
diff --git a/src/armnnSerializer/ArmnnSchema_generated.h b/src/armnnSerializer/ArmnnSchema_generated.h
index dfa4966..fc55d9b 100644
--- a/src/armnnSerializer/ArmnnSchema_generated.h
+++ b/src/armnnSerializer/ArmnnSchema_generated.h
@@ -9853,14 +9853,19 @@
struct FeatureCompatibilityVersions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef FeatureCompatibilityVersionsBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_BINDINGIDSSCHEME = 4
+ VT_BINDINGIDSSCHEME = 4,
+ VT_WEIGHTSLAYOUTSCHEME = 6
};
uint32_t bindingIdsScheme() const {
return GetField<uint32_t>(VT_BINDINGIDSSCHEME, 0);
}
+ uint32_t weightsLayoutScheme() const {
+ return GetField<uint32_t>(VT_WEIGHTSLAYOUTSCHEME, 0);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_BINDINGIDSSCHEME) &&
+ VerifyField<uint32_t>(verifier, VT_WEIGHTSLAYOUTSCHEME) &&
verifier.EndTable();
}
};
@@ -9872,6 +9877,9 @@
void add_bindingIdsScheme(uint32_t bindingIdsScheme) {
fbb_.AddElement<uint32_t>(FeatureCompatibilityVersions::VT_BINDINGIDSSCHEME, bindingIdsScheme, 0);
}
+ void add_weightsLayoutScheme(uint32_t weightsLayoutScheme) {
+ fbb_.AddElement<uint32_t>(FeatureCompatibilityVersions::VT_WEIGHTSLAYOUTSCHEME, weightsLayoutScheme, 0);
+ }
explicit FeatureCompatibilityVersionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -9886,8 +9894,10 @@
inline flatbuffers::Offset<FeatureCompatibilityVersions> CreateFeatureCompatibilityVersions(
flatbuffers::FlatBufferBuilder &_fbb,
- uint32_t bindingIdsScheme = 0) {
+ uint32_t bindingIdsScheme = 0,
+ uint32_t weightsLayoutScheme = 0) {
FeatureCompatibilityVersionsBuilder builder_(_fbb);
+ builder_.add_weightsLayoutScheme(weightsLayoutScheme);
builder_.add_bindingIdsScheme(bindingIdsScheme);
return builder_.Finish();
}
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 944797f..30a7e74 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1787,7 +1787,8 @@
flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> versionsTable =
serializer::CreateFeatureCompatibilityVersions(
m_flatBufferBuilder,
- 1 // Binding ids scheme version
+ 1, // Binding ids scheme version
+ 1 // Weights layout scheme version
);
return versionsTable;
}