IVGCVSW-4259 Add frontend and reference workload for UnaryOperationLayer
* Added new layer named ElementwiseUnary
* Deprecated existing Abs/Rsqrt layer functions
* Updated existing Abs/Rsqrt test infrastructure to use new layer
* Added boilerplate for new Exp, Neg, Sqrt elementwise op layers
* AbsQuantize test removed pending future commit
* Serialization support added
!android-nn-driver:2550
Change-Id: Ic595c645925e17b45db568187fd05646daf2e87f
Signed-off-by: josh minor <josh.minor@arm.com>
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 26a61d4..491081d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -70,28 +70,10 @@
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference abs: input type not supported");
-
- supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
- "Reference abs: output type not supported");
-
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference abs: input and output types not matching");
-
- supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
- "Reference abs: input and output shapes have different number of total elements");
-
- return supported;
+ return IsElementwiseUnarySupported(input,
+ output,
+ ElementwiseUnaryDescriptor(UnaryOperation::Abs),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
@@ -714,6 +696,39 @@
return supported;
}
+bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
+                                                  const TensorInfo& output,
+                                                  const ElementwiseUnaryDescriptor& descriptor,
+                                                  Optional<std::string&> reasonIfUnsupported) const
+{
+    boost::ignore_unused(descriptor); // The rules below do not depend on which unary operation is requested.
+    // Data types accepted for both the input and the output tensor.
+    std::array<DataType, 4> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
+    };
+
+    bool supported = true;
+    // Every rule is evaluated (no early return); the results are AND-ed into 'supported'.
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference elementwise unary: input type not supported");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference elementwise unary: output type not supported");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference elementwise unary: input and output types not matching");
+    // Adjacent string literals are concatenated verbatim, so the first one must end with a space.
+    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+                                  "Reference elementwise unary: input and output shapes "
+                                  "have different number of total elements");
+
+    return supported;
+}
+
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -1499,28 +1514,10 @@
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference rsqrt: input type not supported");
-
- supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
- "Reference rsqrt: output type not supported");
-
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference rsqrt: input and output types not matching");
-
- supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
- "Reference Rsqrt: input and output shapes have different number of total elements");
-
- return supported;
+ return IsElementwiseUnarySupported(input,
+ output,
+ ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,