IVGCVSW-6175 Add Pooling3d to Neon
* Add IsSupported for Pooling3d
* Add CreateWorkload case for Pooling3d
* Create new NeonPooling3dWorkload header and source files
* Add Pooling3d workload to NeonWorkloads.hpp
* Add float32 tests for Pooling3d workload
* Add Uint8 tests for Cl and Neon Pooling3d
Signed-off-by: Ryan OShea <ryan.oshea3@arm.com>
Change-Id: Ic992e1233d1eb8db52df2c8446183df1c907bc4d
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index d50b253..cf541f4 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -58,6 +58,7 @@
#include "workloads/NeonPadWorkload.hpp"
#include "workloads/NeonPermuteWorkload.hpp"
#include "workloads/NeonPooling2dWorkload.hpp"
+#include "workloads/NeonPooling3dWorkload.hpp"
#include "workloads/NeonPreluWorkload.hpp"
#include "workloads/NeonQLstmWorkload.hpp"
#include "workloads/NeonQuantizeWorkload.hpp"
@@ -435,6 +436,11 @@
infos[1],
*(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
reasonIfUnsupported);
+ case LayerType::Pooling3d:
+ return IsPooling3dSupported(infos[0],
+ infos[1],
+ *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::Prelu:
return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
case LayerType::QLstm:
@@ -578,7 +584,7 @@
default:
// layers not supported in neon by default:
// debug, fakequantization, precompiled,
- // standin, switch, pooling3d
+ // standin, switch
return false;
}
}
@@ -1213,6 +1219,14 @@
FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
+bool NeonLayerSupport::IsPooling3dSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const Pooling3dDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPooling3dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
+}
+
bool NeonLayerSupport::IsPreluSupported(const armnn::TensorInfo &input,
const armnn::TensorInfo &alpha,
const armnn::TensorInfo &output,