GitHub #667: Neon fold padding into average pool 2D quantization bug fix.
* Originated from a GitHub issue: https://github.com/ARM-software/armnn/issues/667
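* Arm NN initially accepts the Pooling2d operation on Neon because the pooling
layer itself has no padding. However, when a Pad layer is followed by an average
Pooling2d, the folding optimization moves that padding into the pooling layer,
and the resulting padded average pool fails on Neon.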
* This change stops the generic folding optimization from being applied in the
above special case, and instead performs the fold as a backend-specific optimization.
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: Ia0fd90c3a6b4b9d29c81106f154617d2e893e26b
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index 1fe53de..d2e8fbf 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -566,6 +566,31 @@
untouched.erase(baseLayer->GetGuid());
}
}
+
+ // Special case to fuse padding into average pooling 2d for quantized datatype.
+ // Required to be done as a backend specific optimization as Neon does not support this special case.
+ if (base.GetType() == LayerType::Pooling2d)
+ {
+ Pooling2dLayer* baseLayer = PolymorphicDowncast<Pooling2dLayer*>(&base);
+ Pooling2dDescriptor poolingDescriptor = baseLayer->GetParameters();
+
+ if (baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer().GetType() == LayerType::Pad)
+ {
+ PadLayer* padLayer = PolymorphicDowncast<PadLayer*>(
+ &baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer());
+ if (padLayer->GetOutputSlot(0).GetNumConnections() == 1 &&
+ optimizations::pad_fold::TryFoldPadIntoLayer2d(padLayer->GetParameters(),
+ poolingDescriptor,
+ padLayer->GetOutputSlot().GetTensorInfo(),
+ true))
+ {
+ FoldPadIntoAveragePool2d<Pooling2dLayer>(optimizationViews, baseLayer,
+ poolingDescriptor, padLayer);
+ untouched.erase(baseLayer->GetGuid());
+ untouched.erase(padLayer->GetGuid());
+ }
+ }
+ }
}
if (optimizationViews.GetSubstitutions().empty())